mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
feat(router): add max_total_tokens and empty_input validation
This commit is contained in:
parent
68455353f5
commit
bfdd8de903
@ -20,6 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
|
||||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "512", long, env)]
|
||||
max_max_new_tokens: u32,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "1512", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1000", long, env)]
|
||||
max_input_length: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
@ -46,7 +52,10 @@ fn main() -> Result<(), std::io::Error> {
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
@ -92,7 +101,10 @@ fn main() -> Result<(), std::io::Error> {
|
||||
// Run server
|
||||
server::run(
|
||||
max_concurrent_requests,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
|
@ -291,7 +291,10 @@ async fn generate_stream(
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
max_concurrent_requests: usize,
|
||||
max_max_new_tokens: u32,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
@ -333,7 +336,14 @@ pub async fn run(
|
||||
struct ApiDoc;
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::validation::ValidationError::EmptyInput;
|
||||
/// Payload validation logic
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::rngs::ThreadRng;
|
||||
@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{instrument, Span};
|
||||
|
||||
const MAX_MAX_NEW_TOKENS: u32 = 512;
|
||||
const MAX_STOP_SEQUENCES: usize = 4;
|
||||
|
||||
/// Validation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
@ -19,7 +17,14 @@ pub struct Validation {
|
||||
}
|
||||
|
||||
impl Validation {
|
||||
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
||||
pub(crate) fn new(
|
||||
workers: usize,
|
||||
tokenizer: Tokenizer,
|
||||
max_max_new_tokens: u32,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||
|
||||
@ -27,7 +32,10 @@ impl Validation {
|
||||
tokio::spawn(validation_task(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
validation_receiver,
|
||||
));
|
||||
|
||||
@ -61,7 +69,10 @@ impl Validation {
|
||||
async fn validation_task(
|
||||
workers: usize,
|
||||
tokenizer: Tokenizer,
|
||||
max_max_new_tokens: u32,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
let mut workers_senders = Vec::with_capacity(workers);
|
||||
@ -75,7 +86,14 @@ async fn validation_task(
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
||||
validation_worker(
|
||||
tokenizer_clone,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
worker_receiver,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@ -95,7 +113,10 @@ async fn validation_task(
|
||||
/// the tokenizer
|
||||
fn validation_worker(
|
||||
tokenizer: Tokenizer,
|
||||
max_max_new_tokens: u32,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
// Seed rng
|
||||
@ -106,7 +127,16 @@ fn validation_worker(
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(
|
||||
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
|
||||
validate(
|
||||
request,
|
||||
&tokenizer,
|
||||
max_max_new_tokens,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
&mut rng,
|
||||
)
|
||||
.map_err(|err| {
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}),
|
||||
@ -119,7 +149,10 @@ fn validation_worker(
|
||||
fn validate(
|
||||
request: GenerateRequest,
|
||||
tokenizer: &Tokenizer,
|
||||
max_max_new_tokens: u32,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
rng: &mut ThreadRng,
|
||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||
let GenerateParameters {
|
||||
@ -161,13 +194,13 @@ fn validate(
|
||||
}
|
||||
}?;
|
||||
|
||||
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
|
||||
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
||||
if max_new_tokens == 0 || max_new_tokens > max_max_new_tokens {
|
||||
return Err(ValidationError::MaxNewTokens(max_max_new_tokens));
|
||||
}
|
||||
|
||||
if stop_sequences.len() > MAX_STOP_SEQUENCES {
|
||||
if stop_sequences.len() > max_stop_sequences {
|
||||
return Err(ValidationError::StopSequence(
|
||||
MAX_STOP_SEQUENCES,
|
||||
max_stop_sequences,
|
||||
stop_sequences.len(),
|
||||
));
|
||||
}
|
||||
@ -178,13 +211,23 @@ fn validate(
|
||||
Some(seed) => seed,
|
||||
};
|
||||
|
||||
// Check if inputs is empty
|
||||
if request.inputs.is_empty() {
|
||||
return Err(EmptyInput);
|
||||
}
|
||||
|
||||
// Get the number of tokens in the input
|
||||
match tokenizer.encode(request.inputs.clone(), true) {
|
||||
Ok(encoding) => {
|
||||
let input_length = encoding.len();
|
||||
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
if input_length > max_input_length {
|
||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
||||
Err(ValidationError::InputLength(max_input_length, input_length))
|
||||
} else if total_tokens > max_total_tokens {
|
||||
Err(ValidationError::MaxTotalTokens(
|
||||
max_total_tokens,
|
||||
total_tokens,
|
||||
))
|
||||
} else {
|
||||
// Return ValidGenerateRequest
|
||||
let parameters = NextTokenChooserParameters {
|
||||
@ -238,8 +281,12 @@ pub enum ValidationError {
|
||||
TopK,
|
||||
#[error("max_new_tokens must be strictly positive and <= {0}")]
|
||||
MaxNewTokens(u32),
|
||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||
#[error("input tokens + max_new_tokens must be <= {0}. Given {1}")]
|
||||
MaxTotalTokens(usize, usize),
|
||||
#[error("inputs must have less than {0} tokens. Given: {1}")]
|
||||
InputLength(usize, usize),
|
||||
#[error("inputs cannot be empty")]
|
||||
EmptyInput,
|
||||
#[error("stop supports up to {0} stop sequences. Given: {1}")]
|
||||
StopSequence(usize, usize),
|
||||
#[error("tokenizer error {0}")]
|
||||
|
Loading…
Reference in New Issue
Block a user