use the Rust type system to validate logic

OlivierDehaene 2023-01-31 16:47:06 +01:00
parent 614a1a7202
commit d5ab76cdfb
4 changed files with 58 additions and 47 deletions
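
The diffs below apply the "parse, don't validate" pattern: Validation::validate now returns a dedicated ValidGenerateRequest that already carries the converted NextTokenChooserParameters and StoppingCriteriaParameters, so the queue and batching task can only ever receive requests that passed validation. What follows is a minimal, self-contained sketch of that pattern, not the real code: the struct and field names mirror the diff, but the types are simplified stand-ins (the actual router uses the tokenizer for input_length and the generated protobuf parameter types), and the whitespace-based token count is purely illustrative.

// Minimal sketch with assumed, simplified types (not the real
// text_generation_client structs). Only validate() can produce a
// ValidGenerateRequest, so downstream code cannot see raw requests.

struct GenerateParameters {
    temperature: f32,
    max_new_tokens: u32,
}

struct GenerateRequest {
    inputs: String,
    parameters: GenerateParameters,
}

// Stand-ins for the generated protobuf parameter types.
struct NextTokenChooserParameters {
    temperature: f32,
}

struct StoppingCriteriaParameters {
    max_new_tokens: u32,
}

// Proof-carrying type: holding one means validation already ran.
struct ValidGenerateRequest {
    inputs: String,
    input_length: u32,
    parameters: NextTokenChooserParameters,
    stopping_parameters: StoppingCriteriaParameters,
}

#[derive(Debug)]
enum ValidationError {
    Temperature,
}

fn validate(request: GenerateRequest) -> Result<ValidGenerateRequest, ValidationError> {
    if request.parameters.temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }
    // Illustrative only: the real router asks the tokenizer for the length.
    let input_length = request.inputs.split_whitespace().count() as u32;
    Ok(ValidGenerateRequest {
        inputs: request.inputs,
        input_length,
        parameters: NextTokenChooserParameters {
            temperature: request.parameters.temperature,
        },
        stopping_parameters: StoppingCriteriaParameters {
            max_new_tokens: request.parameters.max_new_tokens,
        },
    })
}

// The batching side only accepts the validated type.
fn append(entry: ValidGenerateRequest) {
    println!("queued '{}' ({} input tokens)", entry.inputs, entry.input_length);
}

fn main() {
    let request = GenerateRequest {
        inputs: "hello world".to_string(),
        parameters: GenerateParameters { temperature: 0.7, max_new_tokens: 16 },
    };
    match validate(request) {
        Ok(valid) => append(valid),
        Err(err) => eprintln!("rejected: {err:?}"),
    }
}

Because the only way to obtain a ValidGenerateRequest is to go through validate, forgetting to validate a request becomes a compile error rather than a runtime bug, which is what the commit title refers to.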

View File

@@ -31,7 +31,7 @@ message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
int32 top_k = 2;
/// restricting to top tokens with cumulative probability >= prob_cut_off
float top_p = 3;
/// apply sampling on the logits

View File

@@ -1,14 +1,12 @@
/// This code is massively inspired by Tokio mini-redis
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::{GenerateParameters, GenerateRequest};
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::OwnedSemaphorePermit;
use tokio::time::Instant;
@@ -17,11 +15,9 @@ use tokio::time::Instant;
#[derive(Debug)]
pub(crate) struct Entry {
/// Request
pub request: GenerateRequest,
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created
pub time: Instant,
/// Instant when this entry was added to a batch
@@ -75,9 +71,9 @@ impl State {
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some((&entry.request.parameters).into()),
stopping_parameters: Some(entry.request.parameters.clone().into()),
input_length: entry.request.input_length,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
ids.push(*id);
@@ -162,25 +158,3 @@ impl Db {
None
}
}
impl From<&GenerateParameters> for NextTokenChooserParameters {
fn from(parameters: &GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
// FIXME: remove unwrap
seed: parameters.seed.unwrap(),
}
}
}
impl From<GenerateParameters> for StoppingCriteriaParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
stop_sequences: parameters.stop,
max_new_tokens: parameters.max_new_tokens,
}
}
}

View File

@@ -78,16 +78,15 @@ impl Infer {
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
// Validate request
let (input_length, validated_request) = self.validation.validate(request).await?;
let valid_request = self.validation.validate(request).await?;
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the database
self.db.append(Entry {
request: validated_request,
request: valid_request,
response_tx,
input_length,
time: Instant::now(),
batch_time: None,
_permit: permit,

View File

@@ -1,7 +1,8 @@
/// Payload validation logic
use crate::GenerateRequest;
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
@@ -38,7 +39,7 @@ impl Validation {
pub(crate) async fn validate(
&self,
request: GenerateRequest,
) -> Result<(usize, GenerateRequest), ValidationError> {
) -> Result<ValidGenerateRequest, ValidationError> {
// Create response channel
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
@@ -104,11 +105,11 @@ fn validation_worker(
}
fn validate(
mut request: GenerateRequest,
request: GenerateRequest,
tokenizer: &Tokenizer,
max_input_length: usize,
rng: &mut ThreadRng,
) -> Result<(usize, GenerateRequest), ValidationError> {
) -> Result<ValidGenerateRequest, ValidationError> {
if request.parameters.temperature <= 0.0 {
return Err(ValidationError::Temperature);
}
@@ -129,19 +130,48 @@ fn validate(
}
// If seed is None, assign a random one
if request.parameters.seed.is_none() {
request.parameters.seed = Some(rng.gen());
}
let seed = match request.parameters.seed {
None => rng.gen(),
Some(seed) => seed,
};
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) {
Ok(inputs) => {
let input_length = inputs.len();
Ok(encoding) => {
let input_length = encoding.len();
if input_length > max_input_length {
Err(ValidationError::InputLength(input_length, max_input_length))
} else {
Ok((input_length, request))
// Return ValidGenerateRequest
let GenerateParameters {
temperature,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
..
} = request.parameters;
let parameters = NextTokenChooserParameters {
temperature,
top_k,
top_p,
do_sample,
seed,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
};
Ok(ValidGenerateRequest {
inputs: request.inputs,
input_length: input_length as u32,
parameters,
stopping_parameters,
})
}
}
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
@@ -150,9 +180,17 @@ fn validate(
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
);
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("temperature must be strictly positive")]