diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 81c0d38f..240282d9 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -120,10 +120,11 @@ impl Infer {
     ) -> Result<Option<tokenizers::Encoding>, InferError> {
         // Tokenize request
         let inputs = request.inputs;
+        let add_special_tokens = request.add_special_tokens;
         let truncate = request.parameters.truncate;
         let encoding = self
             .validation
-            .tokenize(inputs, truncate)
+            .tokenize(inputs, add_special_tokens, truncate)
             .await
             .map_err(|err| {
                 tracing::error!("Tokenization {err}");
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 5a9779d5..fd07840b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1082,6 +1082,16 @@ pub(crate) struct GenerateRequest {
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true")]
+    pub add_special_tokens: bool,
+}
+
+fn default_true() -> bool {
+    true
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -1099,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
     fn from(req: CompatGenerateRequest) -> Self {
         Self {
             inputs: req.inputs,
+            add_special_tokens: true,
             parameters: req.parameters,
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index 8ebd1a33..f273a786 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -158,6 +158,7 @@ async fn get_chat_tokenize(
 
     let generate_request = GenerateRequest {
         inputs,
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -754,6 +755,7 @@ async fn completions(
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
+            add_special_tokens: true,
             parameters: GenerateParameters {
                 best_of: None,
                 temperature,
@@ -1180,6 +1182,7 @@ async fn chat_completions(
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
         .map(|instance| {
             let generate_request = GenerateRequest {
                 inputs: instance.inputs.clone(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
                     max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 0024723c..3c2e706b 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -95,6 +95,7 @@ impl Validation {
     pub async fn tokenize(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
     ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
         // If we have a fast tokenizer
@@ -104,7 +105,11 @@
             // Send request to the background validation task
             // Unwrap is safe here
             sender
-                .send(((inputs, truncate), response_sender, Span::current()))
+                .send((
+                    (inputs, add_special_tokens, truncate),
+                    response_sender,
+                    Span::current(),
+                ))
                 .unwrap();
 
             // Await on response channel
@@ -121,11 +126,15 @@
     async fn validate_input(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+        if let Some((encoding, inputs)) = self
+            .tokenize(inputs.clone(), add_special_tokens, truncate)
+            .await?
+        {
             // Create response channel
             let input_length = if let Some(truncate) = truncate {
                 std::cmp::min(encoding.len(), truncate)
@@ -324,7 +333,12 @@ impl Validation {
 
         // Validate inputs
         let (inputs, input_ids, input_length, max_new_tokens) = self
-            .validate_input(request.inputs, truncate, max_new_tokens)
+            .validate_input(
+                request.inputs,
+                request.add_special_tokens,
+                truncate,
+                max_new_tokens,
+            )
             .await?;
 
         // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@@ -449,12 +463,15 @@ fn tokenizer_worker(
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
-    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
+    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+        receiver.blocking_recv()
+    {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(
                     inputs,
                     truncate,
+                    add_special_tokens,
                     &tokenizer,
                     config.as_ref(),
                     preprocessor_config.as_ref(),
@@ -591,6 +608,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
+    add_special_tokens: bool,
     tokenizer: &Tokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
@@ -628,14 +646,14 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, true)
+        .encode(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
 }
 
 type TokenizerRequest = (
-    (String, Option<usize>),
+    (String, bool, Option<usize>),
     oneshot::Sender<Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError>>,
     Span,
 );
@@ -826,7 +844,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -861,7 +879,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -895,6 +913,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     best_of: Some(2),
                     do_sample: false,
@@ -934,6 +953,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(1.0),
                     max_new_tokens: Some(5),
@@ -949,6 +969,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(0.99),
                     max_new_tokens: Some(5),
@@ -964,6 +985,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: None,
                     max_new_tokens: Some(5),
@@ -1002,6 +1024,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(5),
                     max_new_tokens: Some(5),
@@ -1017,6 +1040,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(4),
                     max_new_tokens: Some(5),
@@ -1029,6 +1053,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(0),
                     max_new_tokens: Some(5),
@@ -1041,6 +1066,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: None,
                     max_new_tokens: Some(5),
@@ -1089,6 +1115,7 @@ mod tests {
         let chunks = match validation
             .tokenize(
                 format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
+                true,
                 None,
             )
             .await
@@ -1148,6 +1175,7 @@ mod tests {
                     "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
                     PIXEL_GIF, PIXEL_GIF
                 ),
+                true,
                 None,
             )
             .await
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 968eaf1d..a2c218eb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -266,6 +266,7 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
+            assert prefix_len <= orig_input_length
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
@@ -282,6 +283,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
+            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
             request_position_ids = torch.arange(
                 prefix_len, orig_input_length, dtype=torch.int32
             )