Change add_special_tokens in order to have the correct tokens for chat

input and non-chat input (since it's super important with the prefixing now)
This commit is contained in:
Nicolas Patry 2024-08-27 11:51:29 +02:00
parent f1c0735453
commit 7f1816a4e1
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
5 changed files with 55 additions and 9 deletions

View File

@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> { ) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request // Tokenize request
let inputs = request.inputs; let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
let truncate = request.parameters.truncate; let truncate = request.parameters.truncate;
let encoding = self let encoding = self
.validation .validation
.tokenize(inputs, truncate) .tokenize(inputs, add_special_tokens, truncate)
.await .await
.map_err(|err| { .map_err(|err| {
tracing::error!("Tokenization {err}"); tracing::error!("Tokenization {err}");

View File

@ -1082,6 +1082,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String, pub inputs: String,
#[serde(default = "default_parameters")] #[serde(default = "default_parameters")]
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
/// This is used internally because some requests
/// already contain the templated input therefore
/// we shouldn't add the special tokens.
#[serde(default = "default_true")]
pub add_special_tokens: bool,
}
fn default_true() -> bool {
true
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1099,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self { fn from(req: CompatGenerateRequest) -> Self {
Self { Self {
inputs: req.inputs, inputs: req.inputs,
add_special_tokens: true,
parameters: req.parameters, parameters: req.parameters,
} }
} }

View File

@ -158,6 +158,7 @@ async fn get_chat_tokenize(
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs, inputs,
add_special_tokens: false,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -754,6 +755,7 @@ async fn completions(
.iter() .iter()
.map(|prompt| GenerateRequest { .map(|prompt| GenerateRequest {
inputs: prompt.to_string(), inputs: prompt.to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -1180,6 +1182,7 @@ async fn chat_completions(
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
.map(|instance| { .map(|instance| {
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: instance.inputs.clone(), inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
do_sample: true, do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),

View File

@ -95,6 +95,7 @@ impl Validation {
pub async fn tokenize( pub async fn tokenize(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
@ -104,7 +105,11 @@ impl Validation {
// Send request to the background validation task // Send request to the background validation task
// Unwrap is safe here // Unwrap is safe here
sender sender
.send(((inputs, truncate), response_sender, Span::current())) .send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
@ -121,11 +126,15 @@ impl Validation {
async fn validate_input( async fn validate_input(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel // Create response channel
let input_length = if let Some(truncate) = truncate { let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate) std::cmp::min(encoding.len(), truncate)
@ -324,7 +333,12 @@ impl Validation {
// Validate inputs // Validate inputs
let (inputs, input_ids, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?; .await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@ -449,12 +463,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) { ) {
// Loop over requests // Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send(prepare_input( .send(prepare_input(
inputs, inputs,
truncate, truncate,
add_special_tokens,
&tokenizer, &tokenizer,
config.as_ref(), config.as_ref(),
preprocessor_config.as_ref(), preprocessor_config.as_ref(),
@ -591,6 +608,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input( fn prepare_input(
inputs: String, inputs: String,
_truncate: Option<usize>, _truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: Option<&Config>, config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>, preprocessor_config: Option<&HubPreprocessorConfig>,
@ -628,14 +646,14 @@ fn prepare_input(
// Get the number of tokens in the input // Get the number of tokens in the input
let encoding = tokenizer let encoding = tokenizer
.encode(tokenizer_query, true) .encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks)) Ok((encoding, input_chunks))
} }
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span, Span,
); );
@ -826,7 +844,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
// Err(ValidationError::MaxNewTokens(1, 10)) => (), // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@ -861,7 +879,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -895,6 +913,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: Some(2), best_of: Some(2),
do_sample: false, do_sample: false,
@ -934,6 +953,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(1.0), top_p: Some(1.0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -949,6 +969,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(0.99), top_p: Some(0.99),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -964,6 +985,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: None, top_p: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1002,6 +1024,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(5), top_n_tokens: Some(5),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1017,6 +1040,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(4), top_n_tokens: Some(4),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1029,6 +1053,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(0), top_n_tokens: Some(0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1041,6 +1066,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: None, top_n_tokens: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1089,6 +1115,7 @@ mod tests {
let chunks = match validation let chunks = match validation
.tokenize( .tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF), format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None, None,
) )
.await .await
@ -1148,6 +1175,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF PIXEL_GIF, PIXEL_GIF
), ),
true,
None, None,
) )
.await .await

View File

@ -266,6 +266,7 @@ class FlashCausalLMBatch(Batch):
orig_input_length = len(tokenized_input) orig_input_length = len(tokenized_input)
prefix_len = r.prefix_len prefix_len = r.prefix_len
assert prefix_len <= orig_input_length
if prefix_len == orig_input_length: if prefix_len == orig_input_length:
assert prefix_len > 0 assert prefix_len > 0
prefix_len -= 1 prefix_len -= 1
@ -282,6 +283,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
# Position ids # Position ids
print(f"Prefix {prefix_len} - Orig {orig_input_length}")
request_position_ids = torch.arange( request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32 prefix_len, orig_input_length, dtype=torch.int32
) )