From d5ab76cdfbc4bcd3eefca1ae5a55a719f45d8cc5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 31 Jan 2023 16:47:06 +0100 Subject: [PATCH] use Rust type system to validate logic --- proto/generate.proto | 2 +- router/src/db.rs | 38 ++++--------------------- router/src/infer.rs | 5 ++-- router/src/validation.rs | 60 ++++++++++++++++++++++++++++++++-------- 4 files changed, 58 insertions(+), 47 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 8f431c5c..11455fee 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -31,7 +31,7 @@ message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; /// restricting to the k highest probability elements - uint32 top_k = 2; + int32 top_k = 2; /// restricting to top tokens summing to prob_cut_off <= prob_cut_off float top_p = 3; /// apply sampling on the logits diff --git a/router/src/db.rs b/router/src/db.rs index 4959fc47..246e4d5d 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,14 +1,12 @@ /// This code is massively inspired by Tokio mini-redis use crate::infer::InferError; use crate::infer::InferStreamResponse; -use crate::{GenerateParameters, GenerateRequest}; +use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; -use text_generation_client::{ - Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters, -}; +use text_generation_client::{Batch, Request}; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::OwnedSemaphorePermit; use tokio::time::Instant; @@ -17,11 +15,9 @@ use tokio::time::Instant; #[derive(Debug)] pub(crate) struct Entry { /// Request - pub request: GenerateRequest, + pub request: ValidGenerateRequest, /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: UnboundedSender>, - /// Number of tokens in the input - pub input_length: usize, /// Instant when this entry was created pub time: Instant, /// Instant when this entry was added to a batch @@ -75,9 +71,9 @@ impl State { requests.push(Request { id: *id, inputs: entry.request.inputs.clone(), - input_length: entry.input_length as u32, - parameters: Some((&entry.request.parameters).into()), - stopping_parameters: Some(entry.request.parameters.clone().into()), + input_length: entry.request.input_length, + parameters: Some(entry.request.parameters.clone()), + stopping_parameters: Some(entry.request.stopping_parameters.clone()), }); ids.push(*id); @@ -162,25 +158,3 @@ impl Db { None } } - -impl From<&GenerateParameters> for NextTokenChooserParameters { - fn from(parameters: &GenerateParameters) -> Self { - Self { - temperature: parameters.temperature, - top_k: parameters.top_k as u32, - top_p: parameters.top_p, - do_sample: parameters.do_sample, - // FIXME: remove unwrap - seed: parameters.seed.unwrap(), - } - } -} - -impl From for StoppingCriteriaParameters { - fn from(parameters: GenerateParameters) -> Self { - Self { - stop_sequences: parameters.stop, - max_new_tokens: parameters.max_new_tokens, - } - } -} diff --git a/router/src/infer.rs b/router/src/infer.rs index 4c4a7eb8..23e84265 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -78,16 +78,15 @@ impl Infer { let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?; // Validate request - let (input_length, validated_request) = self.validation.validate(request).await?; + let valid_request = self.validation.validate(request).await?; // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); // Append the request to the database self.db.append(Entry { - request: validated_request, + request: valid_request, response_tx, - input_length, time: Instant::now(), batch_time: None, _permit: permit, diff --git a/router/src/validation.rs b/router/src/validation.rs index d9579774..99a5a975 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,7 +1,8 @@ /// Payload validation logic -use crate::GenerateRequest; +use crate::{GenerateParameters, GenerateRequest}; use rand::rngs::ThreadRng; use rand::Rng; +use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; @@ -38,7 +39,7 @@ impl Validation { pub(crate) async fn validate( &self, request: GenerateRequest, - ) -> Result<(usize, GenerateRequest), ValidationError> { + ) -> Result { // Create response channel let (sender, receiver) = oneshot::channel(); // Send request to the background validation task @@ -104,11 +105,11 @@ fn validation_worker( } fn validate( - mut request: GenerateRequest, + request: GenerateRequest, tokenizer: &Tokenizer, max_input_length: usize, rng: &mut ThreadRng, -) -> Result<(usize, GenerateRequest), ValidationError> { +) -> Result { if request.parameters.temperature <= 0.0 { return Err(ValidationError::Temperature); } @@ -129,19 +130,48 @@ fn validate( } // If seed is None, assign a random one - if request.parameters.seed.is_none() { - request.parameters.seed = Some(rng.gen()); - } + let seed = match request.parameters.seed { + None => rng.gen(), + Some(seed) => seed, + }; // Get the number of tokens in the input match tokenizer.encode(request.inputs.clone(), true) { - Ok(inputs) => { - let input_length = inputs.len(); + Ok(encoding) => { + let input_length = encoding.len(); if input_length > max_input_length { Err(ValidationError::InputLength(input_length, max_input_length)) } else { - Ok((input_length, request)) + // Return ValidGenerateRequest + let GenerateParameters { + temperature, + top_k, + top_p, + do_sample, + max_new_tokens, + stop: stop_sequences, + .. + } = request.parameters; + + let parameters = NextTokenChooserParameters { + temperature, + top_k, + top_p, + do_sample, + seed, + }; + let stopping_parameters = StoppingCriteriaParameters { + max_new_tokens, + stop_sequences, + }; + + Ok(ValidGenerateRequest { + inputs: request.inputs, + input_length: input_length as u32, + parameters, + stopping_parameters, + }) } } Err(err) => Err(ValidationError::Tokenizer(err.to_string())), @@ -150,9 +180,17 @@ fn validate( type ValidationRequest = ( GenerateRequest, - oneshot::Sender>, + oneshot::Sender>, ); +#[derive(Debug)] +pub(crate) struct ValidGenerateRequest { + pub inputs: String, + pub input_length: u32, + pub parameters: NextTokenChooserParameters, + pub stopping_parameters: StoppingCriteriaParameters, +} + #[derive(Error, Debug)] pub enum ValidationError { #[error("temperature must be strictly positive")]