diff --git a/Cargo.toml b/Cargo.toml index 4e3ad010..09a3a4b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] base64 = "0.22.0" tokenizers = { version = "0.20.0", features = ["http"] } -hf-hub = { version = "0.4.1", features = ["tokio"] } +hf-hub = { version = "0.4.2", features = ["tokio"] } metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } diff --git a/router/src/chat.rs b/router/src/chat.rs index ac132e20..1a18f030 100644 --- a/router/src/chat.rs +++ b/router/src/chat.rs @@ -151,6 +151,7 @@ fn create_event_from_stream_token( )) } +#[derive(Debug)] enum StreamState { /// Before the tools was parsed Buffering, @@ -200,6 +201,7 @@ impl ChatState { pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec { let mut events = vec![]; let token_text = &stream_token.token.text; + println!("Got {token_text:?} - State {:?}", self.state); match self.state { StreamState::Buffering => { self.text.push_str(token_text); @@ -223,9 +225,9 @@ impl ChatState { // XXX Caution, here we do not postfix the quote, so that the current output // Is necessarily finished with quotes for us to be able to parse. 
let partial = &self.text; - let partial = partial.trim_end(); - let partial = partial.trim_end_matches(','); + let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ','); if let Ok(call) = serde_json::from_str::(&format!("{}}}}}", partial)) { + // This can be no_tool before the content has been emitted if call.function._name != "no_tool" { stream_token.token.text = "{".to_string(); let chat_complete = create_event_from_stream_token( @@ -279,30 +281,35 @@ impl ChatState { StreamState::NoToolFinish => {} StreamState::NoTool => { self.text.push_str(token_text); - if token_text.contains("\"") || token_text.contains("}") { - let total_text = &self.text; - let total_text = total_text.trim_end(); - let total_text = total_text.trim_end_matches('}'); - let total_text = total_text.trim_end(); - let total_text = total_text.trim_end_matches('"'); - if let Ok(value) = - serde_json::from_str::(&format!("{}\"}}}}", total_text)) - { - if !value.function.content.is_empty() { - let text = token_text.trim_end(); - let text = text.trim_end_matches('}'); - let mut text = text.trim_end(); + if token_text.contains("\"") { + let mut text = self + .text + .trim_end_matches(|c: char| c.is_whitespace() || c == '}'); + // Trim once + if text.ends_with("\"") { + // Verify we have actually trimmed something + // The opposite can happen if the model is outputting inline JSON. 
println!("New state {:?}", self.state);
+ } = &choices[0] + { + assert_eq!(tool_calls.len(), 1); + let DeltaToolCall { + index, + id, + r#type, + function, + } = &tool_calls[0]; + assert_eq!(*index, 0); + assert_eq!(id, ""); + assert_eq!(r#type, "function"); + (function.name.as_ref(), &function.arguments) + } else { + panic!("Expected plain message"); + } + } + _ => panic!("Unexpected chunk"), + } + } + #[test] fn test_chat_stream() { let mut chat_state = ChatState::new( @@ -518,6 +571,83 @@ mod tests { "}".to_string(), "}".to_string(), ]; + let tokens: Vec<_> = tokens + .into_iter() + .map(|text| StreamResponse { + generated_text: None, + token: Token { + id: 42, + text: text.to_string(), + logprob: 0.0, + special: false, + }, + top_tokens: vec![], + index: 0, + details: None, + }) + .collect(); + + // Initial ignored output + for token in &tokens[..14] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 0); + } + + // No tool output + let mut output = String::new(); + for token in &tokens[14..14 + 7] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 1); + let content = get_text_content(&events[0]); + output.push_str(content); + } + + assert_eq!(output, "I am a helpful assistant!"); + + // No tool finish + for token in &tokens[14 + 7..] 
// Escaped quote inside the string content; only the closing real quote gets removed
{ @@ -589,6 +719,157 @@ mod tests { } } + #[test] + fn test_chat_stream_tool_no_tool_inline_json() { + let mut chat_state = ChatState::new( + true, + StreamOptions { + include_usage: true, + }, + "fingerprint".to_string(), + "model_id".to_string(), + false, + ); + + let tokens = vec![ + "{\"".to_string(), + "function".to_string(), + "\":".to_string(), + " {\"".to_string(), + "_".to_string(), + "name".to_string(), + "\":".to_string(), + " \"".to_string(), + "no".to_string(), + "_tool".to_string(), + "\",".to_string(), + " \"".to_string(), + "content".to_string(), + "\":".to_string(), + " \"".to_string(), // Token 14 + "{\\\"".to_string(), // Event 1 + "a".to_string(), // Event 1 + "\\\":".to_string(), // Event 1 + "2".to_string(), // Event 2 + ",\\".to_string(), // Event 2 + "\"".to_string(), // Event 2 + "b".to_string(), // Event 3 + "\\\": ".to_string(), // Event 4 + "1".to_string(), // Event 5 + "}".to_string(), // Event 5 + "\"}".to_string(), // Extra inside the string quote that would get removed + "}".to_string(), + ]; + let tokens: Vec<_> = tokens + .into_iter() + .map(|text| StreamResponse { + generated_text: None, + token: Token { + id: 42, + text: text.to_string(), + logprob: 0.0, + special: false, + }, + top_tokens: vec![], + index: 0, + details: None, + }) + .collect(); + + // Initial ignored output + for token in &tokens[..14] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 0); + } + + // No tool output + let mut output = String::new(); + for token in &tokens[14..14 + 12] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 1, "Current text is {output:?}"); + let content = get_text_content(&events[0]); + output.push_str(content); + } + + assert_eq!(output, "{\"a\":2,\"b\": 1}"); + + // No tool finish + for token in &tokens[14 + 12..] 
{ + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 0, "Extra events {events:?}"); + } + } + + #[test] + fn test_chat_stream_tool_no_tool_empty() { + let mut chat_state = ChatState::new( + true, + StreamOptions { + include_usage: true, + }, + "fingerprint".to_string(), + "model_id".to_string(), + false, + ); + + let tokens = vec![ + "{\"".to_string(), + "function".to_string(), + "\":".to_string(), + " {\"".to_string(), + "_".to_string(), + "name".to_string(), + "\":".to_string(), + " \"".to_string(), + "no".to_string(), + "_tool".to_string(), + "\",".to_string(), + " \"".to_string(), + "content".to_string(), + "\":\"".to_string(), + "\"}".to_string(), // Token 13 + "}".to_string(), // Event 1 + ]; + let tokens: Vec<_> = tokens + .into_iter() + .map(|text| StreamResponse { + generated_text: None, + token: Token { + id: 42, + text: text.to_string(), + logprob: 0.0, + special: false, + }, + top_tokens: vec![], + index: 0, + details: None, + }) + .collect(); + + // Initial ignored output + for token in &tokens[..13] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 0); + } + + // No tool output + let mut output = String::new(); + for token in &tokens[13..13 + 2] { + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 1, "Current text is {output:?}"); + let content = get_text_content(&events[0]); + output.push_str(content); + } + + assert_eq!(output, ""); + + // No tool finish + for token in &tokens[13 + 2..] 
{ + let events = chat_state.push(token.clone()); + assert_eq!(events.len(), 0, "Extra events {events:?}"); + } + } + #[test] fn test_chat_stream_tool_get_weather() { let mut chat_state = ChatState::new( @@ -633,10 +914,9 @@ mod tests { "elsius".to_string(), // Event 17 "\"}}".to_string(), // Event 18 retained (trailing brace removed) ]; - - // Initial ignored output - for text in &tokens[..11] { - let events = chat_state.push(StreamResponse { + let tokens: Vec<_> = tokens + .into_iter() + .map(|text| StreamResponse { generated_text: None, token: Token { id: 42, @@ -647,56 +927,27 @@ mod tests { top_tokens: vec![], index: 0, details: None, - }); + }) + .collect(); + + // Initial ignored output + for token in &tokens[..11] { + let events = chat_state.push(token.clone()); assert_eq!(events.len(), 0, "{events:?}"); } // No tool output let mut output = String::new(); let mut output_name = String::new(); - for text in &tokens[11..11 + 17] { - let events = chat_state.push(StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }); + for token in &tokens[11..11 + 17] { + let events = chat_state.push(token.clone()); assert_eq!(events.len(), 1); - match &events[0] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!(choices.len(), 1); - if let ChatCompletionChoice { - delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }), - .. 
- } = &choices[0] - { - assert_eq!(tool_calls.len(), 1); - let DeltaToolCall { - index, - id, - r#type, - function, - } = &tool_calls[0]; - assert_eq!(*index, 0); - assert_eq!(id, ""); - assert_eq!(r#type, "function"); - if let Some(name) = &function.name { - assert_eq!(name, "get_current_weather"); - output_name.push_str(&name); - } - output.push_str(&function.arguments); - } else { - panic!("Expected plain message"); - } - } - _ => panic!("Unexpected chunk"), + let (name, arguments) = get_tool_call_content(&events[0]); + if let Some(name) = name { + assert_eq!(name, "get_current_weather"); + output_name.push_str(&name); } + output.push_str(arguments); } assert_eq!(output_name, "get_current_weather"); @@ -706,19 +957,8 @@ mod tests { ); // No tool finish - for text in &tokens[11 + 17..] { - let events = chat_state.push(StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }); + for token in &tokens[11 + 17..] 
{ + let events = chat_state.push(token.clone()); assert_eq!(events.len(), 0); } } diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index b179dd4d..ebba5fd3 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result, bos_token: Option, diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 7eb8a41b..cdce8188 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -52,7 +52,7 @@ pub struct Infer { /// Request backend backend: Arc, /// Chat template - chat_template: Option, + pub(crate) chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, /// Backend health diff --git a/router/src/server.rs b/router/src/server.rs index 824a23bb..d68353aa 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1162,6 +1162,8 @@ pub(crate) async fn chat_completions( logprobs, .. } = chat.clone(); + + tracing::debug!("Got chat_template {:?}", infer.chat_template); let (generate_request, using_tools): (GenerateRequest, bool) = chat.try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); @@ -1565,6 +1567,7 @@ pub async fn run( ) } Type::Cache(cache) => { + tracing::info!("Cache {cache:?}"); let repo = cache.repo(Repo::with_revision( tokenizer_name.to_string(), RepoType::Model, @@ -1581,6 +1584,7 @@ pub async fn run( }; // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}"); let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path { HubTokenizerConfig::from_file(filename)