mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
use Rust type system to validate logic
This commit is contained in:
parent
614a1a7202
commit
d5ab76cdfb
@ -31,7 +31,7 @@ message NextTokenChooserParameters {
|
|||||||
/// exponential scaling output probability distribution
|
/// exponential scaling output probability distribution
|
||||||
float temperature = 1;
|
float temperature = 1;
|
||||||
/// restricting to the k highest probability elements
|
/// restricting to the k highest probability elements
|
||||||
uint32 top_k = 2;
|
int32 top_k = 2;
|
||||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||||
float top_p = 3;
|
float top_p = 3;
|
||||||
/// apply sampling on the logits
|
/// apply sampling on the logits
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
/// This code is massively inspired by Tokio mini-redis
|
/// This code is massively inspired by Tokio mini-redis
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::infer::InferStreamResponse;
|
use crate::infer::InferStreamResponse;
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::validation::ValidGenerateRequest;
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{
|
use text_generation_client::{Batch, Request};
|
||||||
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
|
||||||
};
|
|
||||||
use tokio::sync::mpsc::UnboundedSender;
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
use tokio::sync::OwnedSemaphorePermit;
|
use tokio::sync::OwnedSemaphorePermit;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
@ -17,11 +15,9 @@ use tokio::time::Instant;
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct Entry {
|
pub(crate) struct Entry {
|
||||||
/// Request
|
/// Request
|
||||||
pub request: GenerateRequest,
|
pub request: ValidGenerateRequest,
|
||||||
/// Response sender to communicate between the Infer struct and the batching_task
|
/// Response sender to communicate between the Infer struct and the batching_task
|
||||||
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||||
/// Number of tokens in the input
|
|
||||||
pub input_length: usize,
|
|
||||||
/// Instant when this entry was created
|
/// Instant when this entry was created
|
||||||
pub time: Instant,
|
pub time: Instant,
|
||||||
/// Instant when this entry was added to a batch
|
/// Instant when this entry was added to a batch
|
||||||
@ -75,9 +71,9 @@ impl State {
|
|||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: *id,
|
id: *id,
|
||||||
inputs: entry.request.inputs.clone(),
|
inputs: entry.request.inputs.clone(),
|
||||||
input_length: entry.input_length as u32,
|
input_length: entry.request.input_length,
|
||||||
parameters: Some((&entry.request.parameters).into()),
|
parameters: Some(entry.request.parameters.clone()),
|
||||||
stopping_parameters: Some(entry.request.parameters.clone().into()),
|
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||||
});
|
});
|
||||||
|
|
||||||
ids.push(*id);
|
ids.push(*id);
|
||||||
@ -162,25 +158,3 @@ impl Db {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&GenerateParameters> for NextTokenChooserParameters {
|
|
||||||
fn from(parameters: &GenerateParameters) -> Self {
|
|
||||||
Self {
|
|
||||||
temperature: parameters.temperature,
|
|
||||||
top_k: parameters.top_k as u32,
|
|
||||||
top_p: parameters.top_p,
|
|
||||||
do_sample: parameters.do_sample,
|
|
||||||
// FIXME: remove unwrap
|
|
||||||
seed: parameters.seed.unwrap(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<GenerateParameters> for StoppingCriteriaParameters {
|
|
||||||
fn from(parameters: GenerateParameters) -> Self {
|
|
||||||
Self {
|
|
||||||
stop_sequences: parameters.stop,
|
|
||||||
max_new_tokens: parameters.max_new_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -78,16 +78,15 @@ impl Infer {
|
|||||||
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
|
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
|
||||||
|
|
||||||
// Validate request
|
// Validate request
|
||||||
let (input_length, validated_request) = self.validation.validate(request).await?;
|
let valid_request = self.validation.validate(request).await?;
|
||||||
|
|
||||||
// MPSC channel to communicate with the background batching task
|
// MPSC channel to communicate with the background batching task
|
||||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
// Append the request to the database
|
// Append the request to the database
|
||||||
self.db.append(Entry {
|
self.db.append(Entry {
|
||||||
request: validated_request,
|
request: valid_request,
|
||||||
response_tx,
|
response_tx,
|
||||||
input_length,
|
|
||||||
time: Instant::now(),
|
time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
_permit: permit,
|
_permit: permit,
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::GenerateRequest;
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use rand::rngs::ThreadRng;
|
use rand::rngs::ThreadRng;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
@ -38,7 +39,7 @@ impl Validation {
|
|||||||
pub(crate) async fn validate(
|
pub(crate) async fn validate(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<(usize, GenerateRequest), ValidationError> {
|
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
let (sender, receiver) = oneshot::channel();
|
let (sender, receiver) = oneshot::channel();
|
||||||
// Send request to the background validation task
|
// Send request to the background validation task
|
||||||
@ -104,11 +105,11 @@ fn validation_worker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn validate(
|
fn validate(
|
||||||
mut request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
rng: &mut ThreadRng,
|
rng: &mut ThreadRng,
|
||||||
) -> Result<(usize, GenerateRequest), ValidationError> {
|
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||||
if request.parameters.temperature <= 0.0 {
|
if request.parameters.temperature <= 0.0 {
|
||||||
return Err(ValidationError::Temperature);
|
return Err(ValidationError::Temperature);
|
||||||
}
|
}
|
||||||
@ -129,19 +130,48 @@ fn validate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If seed is None, assign a random one
|
// If seed is None, assign a random one
|
||||||
if request.parameters.seed.is_none() {
|
let seed = match request.parameters.seed {
|
||||||
request.parameters.seed = Some(rng.gen());
|
None => rng.gen(),
|
||||||
}
|
Some(seed) => seed,
|
||||||
|
};
|
||||||
|
|
||||||
// Get the number of tokens in the input
|
// Get the number of tokens in the input
|
||||||
match tokenizer.encode(request.inputs.clone(), true) {
|
match tokenizer.encode(request.inputs.clone(), true) {
|
||||||
Ok(inputs) => {
|
Ok(encoding) => {
|
||||||
let input_length = inputs.len();
|
let input_length = encoding.len();
|
||||||
|
|
||||||
if input_length > max_input_length {
|
if input_length > max_input_length {
|
||||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
Err(ValidationError::InputLength(input_length, max_input_length))
|
||||||
} else {
|
} else {
|
||||||
Ok((input_length, request))
|
// Return ValidGenerateRequest
|
||||||
|
let GenerateParameters {
|
||||||
|
temperature,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
do_sample,
|
||||||
|
max_new_tokens,
|
||||||
|
stop: stop_sequences,
|
||||||
|
..
|
||||||
|
} = request.parameters;
|
||||||
|
|
||||||
|
let parameters = NextTokenChooserParameters {
|
||||||
|
temperature,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
do_sample,
|
||||||
|
seed,
|
||||||
|
};
|
||||||
|
let stopping_parameters = StoppingCriteriaParameters {
|
||||||
|
max_new_tokens,
|
||||||
|
stop_sequences,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ValidGenerateRequest {
|
||||||
|
inputs: request.inputs,
|
||||||
|
input_length: input_length as u32,
|
||||||
|
parameters,
|
||||||
|
stopping_parameters,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
|
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
|
||||||
@ -150,9 +180,17 @@ fn validate(
|
|||||||
|
|
||||||
type ValidationRequest = (
|
type ValidationRequest = (
|
||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
|
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct ValidGenerateRequest {
|
||||||
|
pub inputs: String,
|
||||||
|
pub input_length: u32,
|
||||||
|
pub parameters: NextTokenChooserParameters,
|
||||||
|
pub stopping_parameters: StoppingCriteriaParameters,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ValidationError {
|
pub enum ValidationError {
|
||||||
#[error("temperature must be strictly positive")]
|
#[error("temperature must be strictly positive")]
|
||||||
|
Loading…
Reference in New Issue
Block a user