Mirror of https://github.com/huggingface/text-generation-inference.git
support left truncate

parent a376d8bc59
commit d405880504
@@ -63,7 +63,7 @@ pub(crate) struct GenerateParameters {
     pub stop: Vec<String>,
     #[serde(default)]
     #[schema(default = "null", example = "null")]
-    pub truncate: Option<i32>,
+    pub truncate: Option<usize>,
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
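Changing the field from `Option<i32>` to `Option<usize>` moves the sign check into deserialization: a negative `truncate` now fails when the request body is parsed instead of having to be caught by validation code. A minimal sketch of the effect, using a hypothetical stand-in struct `Params` with only this field (`serde` and `serde_json` as the only dependencies):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Params {
    // Stand-in for the `truncate` field of GenerateParameters.
    #[serde(default)]
    truncate: Option<usize>,
}

fn main() {
    // A well-formed request: truncate deserializes to Some(1000).
    let ok: Params = serde_json::from_str(r#"{"truncate": 1000}"#).unwrap();
    assert_eq!(ok.truncate, Some(1000));

    // With Option<usize>, a negative value is rejected at parse time;
    // with the old Option<i32> it would have reached the validation logic.
    assert!(serde_json::from_str::<Params>(r#"{"truncate": -1}"#).is_err());

    // Omitting the field still defaults to None via #[serde(default)].
    let missing: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(missing.truncate, None);
}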
@@ -6,6 +6,7 @@ use rand::Rng;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::TruncationDirection;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};
@@ -157,6 +158,7 @@ fn validate(
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
+        truncate,
         seed,
         watermark,
         ..
@@ -223,23 +225,45 @@ fn validate(
         return Err(EmptyInput);
     }
 
-    // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(mut encoding) => {
+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
+            }
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
-            let input_length = encoding.len();
-            let total_tokens = input_length + max_new_tokens as usize;
+    // Get the number of tokens in the input
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
+
     if input_length > max_input_length {
-                Err(ValidationError::InputLength(max_input_length, input_length))
-            } else if total_tokens > max_total_tokens {
-                Err(ValidationError::MaxTotalTokens(
+        return Err(ValidationError::InputLength(max_input_length, input_length));
+    }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
             max_total_tokens,
             input_length,
             max_new_tokens,
-                ))
-            } else {
+        ));
+    }
 
     // Return ValidGenerateRequest
     let parameters = NextTokenChooserParameters {
         temperature,
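This hunk is the heart of the feature: when `truncate` is set, the encoding is clipped from the left so the most recent tokens survive, and the request text is rebuilt by decoding the remaining ids, which keeps `inputs` consistent with what the model actually sees. A self-contained sketch of the same `tokenizers` calls, assuming tokenizers 0.13 with its `http` feature enabled (`gpt2` is just an example model id):

use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any model with a tokenizer.json works here; gpt2 is only an example.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;

    let mut encoding = tokenizer.encode("one two three four five six", true)?;
    println!("before: {} tokens", encoding.len());

    // Keep only the last 3 tokens, mirroring the router's left truncation.
    encoding.truncate(3, 0, TruncationDirection::Left);

    // Decode the surviving ids back into text, as validate() does to rebuild `inputs`.
    let inputs = tokenizer.decode(Vec::from(encoding.get_ids()), false)?;
    println!("after: {:?} ({} tokens)", inputs, encoding.len());
    Ok(())
}

Left truncation is the natural default for generation: when a prompt is too long, dropping the oldest context and keeping the text closest to the continuation loses the least information.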
@@ -260,15 +284,11 @@ fn validate(
     metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
 
     Ok(ValidGenerateRequest {
-        inputs: request.inputs,
+        inputs,
         input_length: input_length as u32,
         parameters,
        stopping_parameters,
     })
-            }
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
-    }
 }
 
 type ValidationRequest = (
@@ -295,6 +315,8 @@ pub enum ValidationError {
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
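The new `Truncate` variant is what the `map`/`unwrap_or` guard added to `validate` above returns for out-of-range values. A condensed, runnable sketch of that pattern in isolation (`check_truncate` is a hypothetical helper, not part of the router; `thiserror` is the only dependency):

use thiserror::Error;

#[derive(Debug, Error)]
enum ValidationError {
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
}

// Same shape as the router guard: 0 and values above max_input_length are
// rejected, while an unset parameter passes through as None.
fn check_truncate(
    truncate: Option<usize>,
    max_input_length: usize,
) -> Result<Option<usize>, ValidationError> {
    truncate
        .map(|value| {
            if value == 0 || value > max_input_length {
                return Err(ValidationError::Truncate(max_input_length, value));
            }
            Ok(Some(value))
        })
        .unwrap_or(Ok(None))
}

fn main() {
    let max_input_length = 1024;
    assert_eq!(check_truncate(None, max_input_length).unwrap(), None);
    assert_eq!(check_truncate(Some(512), max_input_length).unwrap(), Some(512));
    assert!(check_truncate(Some(0), max_input_length).is_err());
    assert!(check_truncate(Some(4096), max_input_length).is_err());
}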