From cb92acf2803804200c59a3b55744467f6dd3eded Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 10 Mar 2025 21:24:43 +0100 Subject: [PATCH] Removing the no_tool content information. --- router/src/chat.rs | 493 ++++++++++--------------------------------- router/src/server.rs | 35 ++- 2 files changed, 134 insertions(+), 394 deletions(-) diff --git a/router/src/chat.rs b/router/src/chat.rs index 63bd53bf1..f1fd19484 100644 --- a/router/src/chat.rs +++ b/router/src/chat.rs @@ -6,22 +6,6 @@ use crate::{ use serde::Deserialize; use serde_json::Value; -#[derive(Debug, Deserialize)] -#[serde(rename_all = "snake_case")] -enum _NoTool { - NoTool, -} - -#[derive(Debug, Deserialize)] -struct NoToolCall { - _name: _NoTool, - content: String, -} -#[derive(Debug, Deserialize)] -struct NoTool { - function: NoToolCall, -} - #[derive(Debug, Deserialize)] struct ToolCall { _name: String, @@ -34,6 +18,12 @@ struct Call { function: ToolCall, } +#[cfg_attr(test, derive(Debug))] +pub(crate) enum ChatEvent{ + NoTool, + Events(Vec) +} + pub(crate) fn parse_output( generated_text: &str, ) -> Result<(Option>, Option), InferError> { @@ -158,10 +148,6 @@ enum StreamState { Buffering, /// We detected a tool call here Tool, - /// During the `content` part of the tool call - NoTool, - /// Finishing frames of the ToolCall - NoToolFinish, /// This is without tool calling Content, } @@ -202,32 +188,12 @@ impl ChatState { } } - pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec { + pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent { let mut events = vec![]; let token_text = &stream_token.token.text; match self.state { StreamState::Buffering => { self.text.push_str(token_text); - // We have a special match for `no_tool` in order to capture directly the `content` - // key which should be re-emitted as raw text. - if let Ok(value) = serde_json::from_str::(&format!("{}\"}}}}", self.text)) { - self.state = StreamState::NoTool; - // Modifiy the content of the token to be whatever was captured by the JSON - stream_token.token.text = value.function.content; - let chat_complete = create_event_from_stream_token( - &stream_token, - self.logprobs, - false, - self.fingerprint.clone(), - self.model_id.clone(), - None, - self.id.clone(), - ); - - events.push(chat_complete); - } - // XXX Caution, here we do not postfix the quote, so that the current output - // Is necessarily finished with quotes for us to be able to parse. let partial = &self.text; let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ','); if let Ok(call) = serde_json::from_str::(&format!("{}}}}}", partial)) { @@ -246,6 +212,8 @@ impl ChatState { events.push(chat_complete); self.state = StreamState::Tool; + }else{ + return ChatEvent::NoTool; } } } @@ -282,50 +250,6 @@ impl ChatState { events.push(chat_complete); } } - // if we skipped sending the buffer we need to avoid sending the following json key and quotes - // We have remainder tokens, ignore everying, - StreamState::NoToolFinish => {} - StreamState::NoTool => { - self.text.push_str(token_text); - if token_text.contains("\"") { - let mut text = self - .text - .trim_end_matches(|c: char| c.is_whitespace() || c == '}'); - // Trim once - if text.ends_with("\"") { - // Verify we have actually trimmed something - // The opposite can happen if the model is outputting inline JSON. - text = &text[..text.len() - 1]; - if let Ok(_value) = - serde_json::from_str::(&format!("{}\"}}}}", text)) - { - let mut text = token_text - .trim_end_matches(|c: char| c.is_whitespace() || c == '}'); - // Effectively trim_end_match('"', 1) - // because we do not want to eventually trim finishing escaped quotes - // {{"\"Something\""}} - if text.ends_with("\"") { - text = &text[..text.len() - 1]; - } - stream_token.token.text = text.to_string(); - self.state = StreamState::NoToolFinish; - } - } - } - // This escaping is usually inline json escaping and we can therefore remove it. - stream_token.token.text = stream_token.token.text.replace("\\", ""); - let chat_complete = create_event_from_stream_token( - &stream_token, - self.logprobs, - false, - self.fingerprint.clone(), - self.model_id.clone(), - None, - self.id.clone(), - ); - - events.push(chat_complete); - } StreamState::Content => { let chat_complete = create_event_from_stream_token( &stream_token, @@ -373,7 +297,7 @@ impl ChatState { events.push(chat_complete); } } - events + ChatEvent::Events(events) } } @@ -385,24 +309,6 @@ mod tests { use super::*; - fn get_text_content(event: &CompletionType) -> &String { - match event { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!(choices.len(), 1); - if let ChatCompletionChoice { - delta: ChatCompletionDelta::Chat(TextMessage { content, .. }), - .. - } = &choices[0] - { - content - } else { - panic!("Expected plain message"); - } - } - _ => panic!("Unexpected chunk"), - } - } - fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) { match event { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { @@ -456,24 +362,28 @@ mod tests { index: 0, details: None, }); - assert_eq!(events.len(), 1); - match &events[0] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!( - choices, - &[ChatCompletionChoice { - index: 0, - delta: ChatCompletionDelta::Chat(TextMessage { - role: "assistant".to_string(), - content: "Hi".to_string(), - tool_call_id: None, - }), - logprobs: None, - finish_reason: None, - }] - ); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 1); + match &events[0] { + CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { + assert_eq!( + choices, + &[ChatCompletionChoice { + index: 0, + delta: ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: "Hi".to_string(), + tool_call_id: None, + }), + logprobs: None, + finish_reason: None, + }] + ); + } + _ => panic!("Unexpected chunk"), } - _ => panic!("Unexpected chunk"), + }else{ + panic!("Expected chat events"); } } @@ -507,43 +417,47 @@ mod tests { finish_reason: FinishReason::Length, }), }); - assert_eq!(events.len(), 2); - match &events[0] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!( - choices, - &[ChatCompletionChoice { - index: 0, - delta: ChatCompletionDelta::Chat(TextMessage { - role: "assistant".to_string(), - content: "Hi".to_string(), - tool_call_id: None, - }), - logprobs: None, - // HAS A FINISH REASON - finish_reason: Some("length".to_string()), - }] - ); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 2); + match &events[0] { + CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { + assert_eq!( + choices, + &[ChatCompletionChoice { + index: 0, + delta: ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: "Hi".to_string(), + tool_call_id: None, + }), + logprobs: None, + // HAS A FINISH REASON + finish_reason: Some("length".to_string()), + }] + ); + } + _ => panic!("Unexpected chunk"), + } + match &events[1] { + CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => { + assert_eq!( + *usage, + Some(Usage { + prompt_tokens: 2, + completion_tokens: 10, + total_tokens: 12, + }) + ); + } + _ => panic!("Unexpected chunk"), + } + }else{ + panic!("Expected chat events"); } - _ => panic!("Unexpected chunk"), - } - match &events[1] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => { - assert_eq!( - *usage, - Some(Usage { - prompt_tokens: 2, - completion_tokens: 10, - total_tokens: 12, - }) - ); - } - _ => panic!("Unexpected chunk"), - } } #[test] - fn test_chat_stream_tool_no_tool() { + fn test_chat_stream_tool_no_tool_simple() { let mut chat_state = ChatState::new( true, StreamOptions { @@ -597,217 +511,21 @@ mod tests { .collect(); // Initial ignored output - for token in &tokens[..14] { + for token in &tokens[..10] { let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0); - } - - // No tool output - let mut output = String::new(); - for token in &tokens[14..14 + 7] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 1); - let content = get_text_content(&events[0]); - output.push_str(content); - } - - assert_eq!(output, "I am a helpful assistant!"); - - // No tool finish - for token in &tokens[14 + 7..] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0); - } - } - - #[test] - fn test_chat_stream_tool_no_tool_many_quotes() { - let mut chat_state = ChatState::new( - true, - StreamOptions { - include_usage: true, - }, - "fingerprint".to_string(), - "model_id".to_string(), - false, - "0".to_string(), - ); - - let tokens = vec![ - "{\"".to_string(), - "function".to_string(), - "\":".to_string(), - " {\"".to_string(), - "_".to_string(), - "name".to_string(), - "\":".to_string(), - " \"".to_string(), - "no".to_string(), - "_tool".to_string(), - "\",".to_string(), - " \"".to_string(), - "content".to_string(), - "\":".to_string(), - " \"".to_string(), // Token 14 - "I".to_string(), // Event 1 - " am".to_string(), // Event 2 - " a".to_string(), // Event 3 - " helpful".to_string(), // Event 4 - " assistant".to_string(), // Event 5 - "!\\\"\"".to_string(), // Extra inside the string quote that would get removed - "}".to_string(), - "}".to_string(), - ]; - - // Initial ignored output - for text in &tokens[..14] { - let events = chat_state.push(StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }); - assert_eq!(events.len(), 0); - } - - // No tool output - let mut output = String::new(); - for text in &tokens[14..14 + 7] { - let events = chat_state.push(StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }); - assert_eq!(events.len(), 1); - match &events[0] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!(choices.len(), 1); - if let ChatCompletionChoice { - delta: ChatCompletionDelta::Chat(TextMessage { content, .. }), - .. - } = &choices[0] - { - output.push_str(content); - } else { - panic!("Expected plain message"); - } - } - _ => panic!("Unexpected chunk"), + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 0, "{events:?}"); + }else{ + panic!("Expected chat events"); } } - assert_eq!(output, "I am a helpful assistant!\""); - - // No tool finish - for text in &tokens[14 + 7..] { - let events = chat_state.push(StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }); - assert_eq!(events.len(), 0); - } - } - - #[test] - fn test_chat_stream_tool_no_tool_inline_json() { - let mut chat_state = ChatState::new( - true, - StreamOptions { - include_usage: true, - }, - "fingerprint".to_string(), - "model_id".to_string(), - false, - "0".to_string(), - ); - - let tokens = vec![ - "{\"".to_string(), - "function".to_string(), - "\":".to_string(), - " {\"".to_string(), - "_".to_string(), - "name".to_string(), - "\":".to_string(), - " \"".to_string(), - "no".to_string(), - "_tool".to_string(), - "\",".to_string(), - " \"".to_string(), - "content".to_string(), - "\":".to_string(), - " \"".to_string(), // Token 14 - "{\\\"".to_string(), // Event 1 - "a".to_string(), // Event 1 - "\\\":".to_string(), // Event 1 - "2".to_string(), // Event 2 - ",\\".to_string(), // Event 2 - "\"".to_string(), // Event 2 - "b".to_string(), // Event 3 - "\\\": ".to_string(), // Event 4 - "1".to_string(), // Event 5 - "}".to_string(), // Event 5 - "\"}".to_string(), // Extra inside the string quote that would get removed - "}".to_string(), - ]; - let tokens: Vec<_> = tokens - .into_iter() - .map(|text| StreamResponse { - generated_text: None, - token: Token { - id: 42, - text: text.to_string(), - logprob: 0.0, - special: false, - }, - top_tokens: vec![], - index: 0, - details: None, - }) - .collect(); - - // Initial ignored output - for token in &tokens[..14] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0); - } - // No tool output - let mut output = String::new(); - for token in &tokens[14..14 + 12] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 1, "Current text is {output:?}"); - let content = get_text_content(&events[0]); - output.push_str(content); - } - - assert_eq!(output, "{\"a\":2,\"b\": 1}"); - - // No tool finish - for token in &tokens[14 + 12..] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0, "Extra events {events:?}"); + let events = chat_state.push(tokens[10].clone()); + if let ChatEvent::NoTool = events{ + assert!(true); + }else{ + panic!("Expected chat events"); } } @@ -859,26 +577,21 @@ mod tests { .collect(); // Initial ignored output - for token in &tokens[..13] { + for token in &tokens[..10] { let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 0, "{events:?}"); + }else{ + panic!("Expected chat events"); + } } // No tool output - let mut output = String::new(); - for token in &tokens[13..13 + 2] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 1, "Current text is {output:?}"); - let content = get_text_content(&events[0]); - output.push_str(content); - } - - assert_eq!(output, ""); - - // No tool finish - for token in &tokens[13 + 2..] { - let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0, "Extra events {events:?}"); + let events = chat_state.push(tokens[10].clone()); + if let ChatEvent::NoTool = events{ + assert!(true); + }else{ + panic!("Expected chat events"); } } @@ -946,7 +659,11 @@ mod tests { // Initial ignored output for token in &tokens[..11] { let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0, "{events:?}"); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 0, "{events:?}"); + }else{ + panic!("Expected chat events"); + } } // No tool output @@ -954,13 +671,17 @@ mod tests { let mut output_name = String::new(); for token in &tokens[11..11 + 17] { let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 1); - let (name, arguments) = get_tool_call_content(&events[0]); - if let Some(name) = name { - assert_eq!(name, "get_current_weather"); - output_name.push_str(&name); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 1); + let (name, arguments) = get_tool_call_content(&events[0]); + if let Some(name) = name { + assert_eq!(name, "get_current_weather"); + output_name.push_str(&name); + } + output.push_str(arguments); + }else{ + panic!("Expected chat events"); } - output.push_str(arguments); } assert_eq!(output_name, "get_current_weather"); @@ -972,7 +693,11 @@ mod tests { // No tool finish for token in &tokens[11 + 17..] { let events = chat_state.push(token.clone()); - assert_eq!(events.len(), 0); + if let ChatEvent::Events(events) = events{ + assert_eq!(events.len(), 0, "{events:?}"); + }else{ + panic!("Expected chat events"); + } } } } diff --git a/router/src/server.rs b/router/src/server.rs index 689b2f502..2e29fb80c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,4 +1,4 @@ -use crate::chat::ChatState; +use crate::chat::{ChatState, ChatEvent}; /// HTTP Server logic use crate::config::Config; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; @@ -1151,7 +1151,7 @@ pub(crate) async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, - Json(chat): Json, + Json(mut chat): Json, ) -> Result)> { let span = tracing::Span::current(); metrics::counter!("tgi_request_count").increment(1); @@ -1166,7 +1166,7 @@ pub(crate) async fn chat_completions( tracing::debug!("Got chat_template {:?}", infer.chat_template); let id = chat.next_tool_call_id(); let (generate_request, using_tools): (GenerateRequest, bool) = - chat.try_into_generate(&infer)?; + chat.clone().try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); let logprobs = logprobs.unwrap_or_default(); @@ -1179,20 +1179,35 @@ pub(crate) async fn chat_completions( // switch on stream if stream { let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(generate_request), span).await; + generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); - let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id); + let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); while let Some(result) = response_stream.next().await { match result{ Ok(stream_token) => { let events = state.push(stream_token); - for chat_complete in events{ - yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| { - tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e); - Event::default() - })); + match events{ + ChatEvent::NoTool => { + chat.tools = None; + chat.response_format = None; + let (generate_request, using_tools): (GenerateRequest, bool) = + chat.clone().try_into_generate(&infer).unwrap(); + assert_eq!(using_tools, false); + let (_headers, response_stream2) = + generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; + state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone()); + response_stream = Box::pin(response_stream2); + } + ChatEvent::Events(events) => { + for chat_complete in events{ + yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| { + tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + })); + } + } } } Err(err) => yield Ok(err.into_openai_event())