use Rust type system to validate logic

This commit is contained in:
OlivierDehaene 2023-01-31 16:47:06 +01:00
parent 614a1a7202
commit d5ab76cdfb
4 changed files with 58 additions and 47 deletions

View File

@ -31,7 +31,7 @@ message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
/// restricting to the k highest probability elements /// restricting to the k highest probability elements
uint32 top_k = 2; int32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3; float top_p = 3;
/// apply sampling on the logits /// apply sampling on the logits

View File

@ -1,14 +1,12 @@
/// This code is massively inspired by Tokio mini-redis /// This code is massively inspired by Tokio mini-redis
use crate::infer::InferError; use crate::infer::InferError;
use crate::infer::InferStreamResponse; use crate::infer::InferStreamResponse;
use crate::{GenerateParameters, GenerateRequest}; use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{Batch, Request};
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::OwnedSemaphorePermit; use tokio::sync::OwnedSemaphorePermit;
use tokio::time::Instant; use tokio::time::Instant;
@ -17,11 +15,9 @@ use tokio::time::Instant;
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Entry { pub(crate) struct Entry {
/// Request /// Request
pub request: GenerateRequest, pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>, pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created /// Instant when this entry was created
pub time: Instant, pub time: Instant,
/// Instant when this entry was added to a batch /// Instant when this entry was added to a batch
@ -75,9 +71,9 @@ impl State {
requests.push(Request { requests.push(Request {
id: *id, id: *id,
inputs: entry.request.inputs.clone(), inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32, input_length: entry.request.input_length,
parameters: Some((&entry.request.parameters).into()), parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.parameters.clone().into()), stopping_parameters: Some(entry.request.stopping_parameters.clone()),
}); });
ids.push(*id); ids.push(*id);
@ -162,25 +158,3 @@ impl Db {
None None
} }
} }
impl From<&GenerateParameters> for NextTokenChooserParameters {
fn from(parameters: &GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
// FIXME: remove unwrap
seed: parameters.seed.unwrap(),
}
}
}
impl From<GenerateParameters> for StoppingCriteriaParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
stop_sequences: parameters.stop,
max_new_tokens: parameters.max_new_tokens,
}
}
}

View File

@ -78,16 +78,15 @@ impl Infer {
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?; let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
// Validate request // Validate request
let (input_length, validated_request) = self.validation.validate(request).await?; let valid_request = self.validation.validate(request).await?;
// MPSC channel to communicate with the background batching task // MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the database // Append the request to the database
self.db.append(Entry { self.db.append(Entry {
request: validated_request, request: valid_request,
response_tx, response_tx,
input_length,
time: Instant::now(), time: Instant::now(),
batch_time: None, batch_time: None,
_permit: permit, _permit: permit,

View File

@ -1,7 +1,8 @@
/// Payload validation logic /// Payload validation logic
use crate::GenerateRequest; use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng; use rand::rngs::ThreadRng;
use rand::Rng; use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
@ -38,7 +39,7 @@ impl Validation {
pub(crate) async fn validate( pub(crate) async fn validate(
&self, &self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<(usize, GenerateRequest), ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
// Create response channel // Create response channel
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
// Send request to the background validation task // Send request to the background validation task
@ -104,11 +105,11 @@ fn validation_worker(
} }
fn validate( fn validate(
mut request: GenerateRequest, request: GenerateRequest,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
max_input_length: usize, max_input_length: usize,
rng: &mut ThreadRng, rng: &mut ThreadRng,
) -> Result<(usize, GenerateRequest), ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
if request.parameters.temperature <= 0.0 { if request.parameters.temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
} }
@ -129,19 +130,48 @@ fn validate(
} }
// If seed is None, assign a random one // If seed is None, assign a random one
if request.parameters.seed.is_none() { let seed = match request.parameters.seed {
request.parameters.seed = Some(rng.gen()); None => rng.gen(),
} Some(seed) => seed,
};
// Get the number of tokens in the input // Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) { match tokenizer.encode(request.inputs.clone(), true) {
Ok(inputs) => { Ok(encoding) => {
let input_length = inputs.len(); let input_length = encoding.len();
if input_length > max_input_length { if input_length > max_input_length {
Err(ValidationError::InputLength(input_length, max_input_length)) Err(ValidationError::InputLength(input_length, max_input_length))
} else { } else {
Ok((input_length, request)) // Return ValidGenerateRequest
let GenerateParameters {
temperature,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
..
} = request.parameters;
let parameters = NextTokenChooserParameters {
temperature,
top_k,
top_p,
do_sample,
seed,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
};
Ok(ValidGenerateRequest {
inputs: request.inputs,
input_length: input_length as u32,
parameters,
stopping_parameters,
})
} }
} }
Err(err) => Err(ValidationError::Tokenizer(err.to_string())), Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
@ -150,9 +180,17 @@ fn validate(
type ValidationRequest = ( type ValidationRequest = (
GenerateRequest, GenerateRequest,
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>, oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
); );
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ValidationError { pub enum ValidationError {
#[error("temperature must be strictly positive")] #[error("temperature must be strictly positive")]