mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
use Rust type system to validate logic
This commit is contained in:
parent 614a1a7202
commit d5ab76cdfb
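
In outline, this commit applies the "parse, don't validate" pattern: rather than returning the caller's `GenerateRequest` after checking it, validation now parses it into a separate `ValidGenerateRequest` type, so holding one is compile-time proof that validation already ran before batching. A minimal, self-contained sketch of the pattern (simplified stand-in types, not the actual crate code):

    // Stand-ins for GenerateRequest / ValidGenerateRequest.
    struct Unvalidated { temperature: f32 }
    struct Valid { temperature: f32 } // only `validate` constructs this

    fn validate(req: Unvalidated) -> Result<Valid, String> {
        if req.temperature <= 0.0 {
            return Err("temperature must be strictly positive".into());
        }
        Ok(Valid { temperature: req.temperature })
    }

    // Downstream code can only be called with the validated type.
    fn schedule(req: Valid) {
        println!("scheduling with temperature {}", req.temperature);
    }

    fn main() {
        match validate(Unvalidated { temperature: 0.7 }) {
            Ok(valid) => schedule(valid),
            Err(e) => eprintln!("rejected: {e}"),
        }
    }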
proto/generate.proto
@@ -31,7 +31,7 @@ message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
     /// restricting to the k highest probability elements
-    uint32 top_k = 2;
+    int32 top_k = 2;
     /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
     float top_p = 3;
     /// apply sampling on the logits
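
A note on the proto change above: with prost (the protobuf codegen the Rust client code here presumably uses), `uint32` maps to `u32` and `int32` maps to `i32`, so this change turns the generated `top_k` field into an `i32`. That matches `GenerateParameters.top_k` directly and is why the `as u32` cast disappears from `db.rs` below. Abbreviated sketch of the generated struct shape (illustrative, not the real generated file):

    #[derive(Clone, PartialEq, Debug)]
    pub struct NextTokenChooserParameters {
        pub temperature: f32, // float temperature = 1;
        pub top_k: i32,       // int32 top_k = 2; (was `u32` while the field was uint32)
        pub top_p: f32,       // float top_p = 3;
    }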
router/src/db.rs
@@ -1,14 +1,12 @@
 /// This code is massively inspired by Tokio mini-redis
 use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
-use crate::{GenerateParameters, GenerateRequest};
+use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
-};
+use text_generation_client::{Batch, Request};
 use tokio::sync::mpsc::UnboundedSender;
 use tokio::sync::OwnedSemaphorePermit;
 use tokio::time::Instant;
@@ -17,11 +15,9 @@ use tokio::time::Instant;
 #[derive(Debug)]
 pub(crate) struct Entry {
     /// Request
-    pub request: GenerateRequest,
+    pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
     pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
-    /// Number of tokens in the input
-    pub input_length: usize,
     /// Instant when this entry was created
     pub time: Instant,
     /// Instant when this entry was added to a batch
@@ -75,9 +71,9 @@ impl State {
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.input_length as u32,
-                parameters: Some((&entry.request.parameters).into()),
-                stopping_parameters: Some(entry.request.parameters.clone().into()),
+                input_length: entry.request.input_length,
+                parameters: Some(entry.request.parameters.clone()),
+                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });

             ids.push(*id);
@@ -162,25 +158,3 @@ impl Db {
         None
     }
 }
-
-impl From<&GenerateParameters> for NextTokenChooserParameters {
-    fn from(parameters: &GenerateParameters) -> Self {
-        Self {
-            temperature: parameters.temperature,
-            top_k: parameters.top_k as u32,
-            top_p: parameters.top_p,
-            do_sample: parameters.do_sample,
-            // FIXME: remove unwrap
-            seed: parameters.seed.unwrap(),
-        }
-    }
-}
-
-impl From<GenerateParameters> for StoppingCriteriaParameters {
-    fn from(parameters: GenerateParameters) -> Self {
-        Self {
-            stop_sequences: parameters.stop,
-            max_new_tokens: parameters.max_new_tokens,
-        }
-    }
-}
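
Deleting these `From` impls also deletes the `seed: parameters.seed.unwrap()` behind the FIXME: a request whose seed was still `None` at batch time would have panicked the router. In the new flow (see `validation.rs` below) the seed is resolved to a concrete value once, during validation, so that panic path no longer exists. A small before/after sketch, assuming a rand 0.8-style API:

    use rand::Rng;

    // Before (sketch): conversion at batch time could panic on a missing seed.
    fn seed_before(seed: Option<u64>) -> u64 {
        seed.unwrap() // panics if the request never set a seed
    }

    // After (sketch): the seed is materialized once, during validation.
    fn seed_after(seed: Option<u64>, rng: &mut impl Rng) -> u64 {
        match seed {
            None => rng.gen(),
            Some(seed) => seed,
        }
    }

    fn main() {
        let mut rng = rand::thread_rng();
        assert_eq!(seed_after(Some(42), &mut rng), 42);
        let _random = seed_after(None, &mut rng); // no panic possible
        let _ = seed_before(Some(1)); // fine, but `None` here would panic
    }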
router/src/infer.rs
@@ -78,16 +78,15 @@ impl Infer {
         let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;

         // Validate request
-        let (input_length, validated_request) = self.validation.validate(request).await?;
+        let valid_request = self.validation.validate(request).await?;

         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();

         // Append the request to the database
         self.db.append(Entry {
-            request: validated_request,
+            request: valid_request,
             response_tx,
-            input_length,
             time: Instant::now(),
             batch_time: None,
             _permit: permit,
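
Besides carrying the validated type, the new signature replaces the positional `(usize, GenerateRequest)` tuple with a named struct, so call sites like this one no longer destructure by position. A toy illustration of the difference (simplified types, not the crate's own):

    struct Validated {
        inputs: String,
        input_length: u32,
    }

    // Before (sketch): the meaning of `.0` lives only in the reader's head.
    fn validate_tuple(input: String) -> (usize, String) {
        (input.len(), input)
    }

    // After (sketch): every field is named at every call site.
    fn validate_struct(input: String) -> Validated {
        Validated { input_length: input.len() as u32, inputs: input }
    }

    fn main() {
        let (len, s) = validate_tuple("hi".to_string());
        println!("tuple: {len} {s}");
        let v = validate_struct("hello".to_string());
        println!("struct: {} {}", v.input_length, v.inputs);
    }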
router/src/validation.rs
@@ -1,7 +1,8 @@
 /// Payload validation logic
-use crate::GenerateRequest;
+use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
 use rand::Rng;
+use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
@@ -38,7 +39,7 @@ impl Validation {
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,
-    ) -> Result<(usize, GenerateRequest), ValidationError> {
+    ) -> Result<ValidGenerateRequest, ValidationError> {
         // Create response channel
         let (sender, receiver) = oneshot::channel();
         // Send request to the background validation task
@@ -104,11 +105,11 @@ fn validation_worker(
 }

 fn validate(
-    mut request: GenerateRequest,
+    request: GenerateRequest,
     tokenizer: &Tokenizer,
     max_input_length: usize,
     rng: &mut ThreadRng,
-) -> Result<(usize, GenerateRequest), ValidationError> {
+) -> Result<ValidGenerateRequest, ValidationError> {
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
@@ -129,19 +130,48 @@ fn validate(
     }

-    // If seed is None, assign a random one
-    if request.parameters.seed.is_none() {
-        request.parameters.seed = Some(rng.gen());
-    }
+    let seed = match request.parameters.seed {
+        None => rng.gen(),
+        Some(seed) => seed,
+    };

     // Get the number of tokens in the input
     match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(inputs) => {
-            let input_length = inputs.len();
+        Ok(encoding) => {
+            let input_length = encoding.len();

             if input_length > max_input_length {
                 Err(ValidationError::InputLength(input_length, max_input_length))
             } else {
-                Ok((input_length, request))
+                // Return ValidGenerateRequest
+                let GenerateParameters {
+                    temperature,
+                    top_k,
+                    top_p,
+                    do_sample,
+                    max_new_tokens,
+                    stop: stop_sequences,
+                    ..
+                } = request.parameters;
+
+                let parameters = NextTokenChooserParameters {
+                    temperature,
+                    top_k,
+                    top_p,
+                    do_sample,
+                    seed,
+                };
+                let stopping_parameters = StoppingCriteriaParameters {
+                    max_new_tokens,
+                    stop_sequences,
+                };
+
+                Ok(ValidGenerateRequest {
+                    inputs: request.inputs,
+                    input_length: input_length as u32,
+                    parameters,
+                    stopping_parameters,
+                })
             }
         }
         Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
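
Two Rust idioms in the new `validate` body are worth noting: `stop: stop_sequences` binds a struct field under a new name, and `..` discards whichever remaining fields of `GenerateParameters` the conversion does not need. A tiny standalone illustration:

    struct Params {
        stop: Vec<String>,
        max_new_tokens: u32,
        unused: bool,
    }

    fn main() {
        let p = Params { stop: vec!["\n".into()], max_new_tokens: 20, unused: true };
        // `stop: stop_sequences` renames while binding; `..` skips `unused`.
        let Params { stop: stop_sequences, max_new_tokens, .. } = p;
        println!("{stop_sequences:?} max_new_tokens={max_new_tokens}");
    }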
@@ -150,9 +180,17 @@ fn validate(

 type ValidationRequest = (
     GenerateRequest,
-    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
+    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
 );

+#[derive(Debug)]
+pub(crate) struct ValidGenerateRequest {
+    pub inputs: String,
+    pub input_length: u32,
+    pub parameters: NextTokenChooserParameters,
+    pub stopping_parameters: StoppingCriteriaParameters,
+}
+
 #[derive(Error, Debug)]
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
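
For reference, the `#[error(...)]` attribute on `ValidationError` comes from thiserror, which generates the `Display` impl from it. A minimal standalone example of the same shape (only the `Temperature` message appears in this diff; the `InputLength` text below is illustrative):

    use thiserror::Error;

    #[derive(Error, Debug)]
    pub enum ValidationError {
        #[error("temperature must be strictly positive")]
        Temperature,
        // Positional fields interpolate as {0}, {1}, ...
        #[error("input has {0} tokens, more than the maximum of {1}")]
        InputLength(usize, usize),
    }

    fn main() {
        println!("{}", ValidationError::Temperature);
        println!("{}", ValidationError::InputLength(2048, 1000));
    }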