Removing the no_tool content information.

This commit is contained in:
Nicolas Patry 2025-03-10 21:24:43 +01:00
parent f74c36fe0d
commit cb92acf280
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
2 changed files with 134 additions and 394 deletions

View File

@ -6,22 +6,6 @@ use crate::{
use serde::Deserialize; use serde::Deserialize;
use serde_json::Value; use serde_json::Value;
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum _NoTool {
NoTool,
}
#[derive(Debug, Deserialize)]
struct NoToolCall {
_name: _NoTool,
content: String,
}
#[derive(Debug, Deserialize)]
struct NoTool {
function: NoToolCall,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ToolCall { struct ToolCall {
_name: String, _name: String,
@ -34,6 +18,12 @@ struct Call {
function: ToolCall, function: ToolCall,
} }
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatEvent{
NoTool,
Events(Vec<CompletionType>)
}
pub(crate) fn parse_output( pub(crate) fn parse_output(
generated_text: &str, generated_text: &str,
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> { ) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> {
@ -158,10 +148,6 @@ enum StreamState {
Buffering, Buffering,
/// We detected a tool call here /// We detected a tool call here
Tool, Tool,
/// During the `content` part of the tool call
NoTool,
/// Finishing frames of the ToolCall
NoToolFinish,
/// This is without tool calling /// This is without tool calling
Content, Content,
} }
@ -202,32 +188,12 @@ impl ChatState {
} }
} }
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> { pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent {
let mut events = vec![]; let mut events = vec![];
let token_text = &stream_token.token.text; let token_text = &stream_token.token.text;
match self.state { match self.state {
StreamState::Buffering => { StreamState::Buffering => {
self.text.push_str(token_text); self.text.push_str(token_text);
// We have a special match for `no_tool` in order to capture directly the `content`
// key which should be re-emitted as raw text.
if let Ok(value) = serde_json::from_str::<NoTool>(&format!("{}\"}}}}", self.text)) {
self.state = StreamState::NoTool;
// Modifiy the content of the token to be whatever was captured by the JSON
stream_token.token.text = value.function.content;
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
false,
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
}
// XXX Caution, here we do not postfix the quote, so that the current output
// Is necessarily finished with quotes for us to be able to parse.
let partial = &self.text; let partial = &self.text;
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ','); let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) { if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
@ -246,6 +212,8 @@ impl ChatState {
events.push(chat_complete); events.push(chat_complete);
self.state = StreamState::Tool; self.state = StreamState::Tool;
}else{
return ChatEvent::NoTool;
} }
} }
} }
@ -282,50 +250,6 @@ impl ChatState {
events.push(chat_complete); events.push(chat_complete);
} }
} }
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
// We have remainder tokens, ignore everying,
StreamState::NoToolFinish => {}
StreamState::NoTool => {
self.text.push_str(token_text);
if token_text.contains("\"") {
let mut text = self
.text
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
// Trim once
if text.ends_with("\"") {
// Verify we have actually trimmed something
// The opposite can happen if the model is outputting inline JSON.
text = &text[..text.len() - 1];
if let Ok(_value) =
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", text))
{
let mut text = token_text
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
// Effectively trim_end_match('"', 1)
// because we do not want to eventually trim finishing escaped quotes
// {{"\"Something\""}}
if text.ends_with("\"") {
text = &text[..text.len() - 1];
}
stream_token.token.text = text.to_string();
self.state = StreamState::NoToolFinish;
}
}
}
// This escaping is usually inline json escaping and we can therefore remove it.
stream_token.token.text = stream_token.token.text.replace("\\", "");
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
false,
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
}
StreamState::Content => { StreamState::Content => {
let chat_complete = create_event_from_stream_token( let chat_complete = create_event_from_stream_token(
&stream_token, &stream_token,
@ -373,7 +297,7 @@ impl ChatState {
events.push(chat_complete); events.push(chat_complete);
} }
} }
events ChatEvent::Events(events)
} }
} }
@ -385,24 +309,6 @@ mod tests {
use super::*; use super::*;
fn get_text_content(event: &CompletionType) -> &String {
match event {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
..
} = &choices[0]
{
content
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
}
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) { fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
match event { match event {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -456,6 +362,7 @@ mod tests {
index: 0, index: 0,
details: None, details: None,
}); });
if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
match &events[0] { match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -475,6 +382,9 @@ mod tests {
} }
_ => panic!("Unexpected chunk"), _ => panic!("Unexpected chunk"),
} }
}else{
panic!("Expected chat events");
}
} }
#[test] #[test]
@ -507,6 +417,7 @@ mod tests {
finish_reason: FinishReason::Length, finish_reason: FinishReason::Length,
}), }),
}); });
if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 2); assert_eq!(events.len(), 2);
match &events[0] { match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -540,10 +451,13 @@ mod tests {
} }
_ => panic!("Unexpected chunk"), _ => panic!("Unexpected chunk"),
} }
}else{
panic!("Expected chat events");
}
} }
#[test] #[test]
fn test_chat_stream_tool_no_tool() { fn test_chat_stream_tool_no_tool_simple() {
let mut chat_state = ChatState::new( let mut chat_state = ChatState::new(
true, true,
StreamOptions { StreamOptions {
@ -597,217 +511,21 @@ mod tests {
.collect(); .collect();
// Initial ignored output // Initial ignored output
for token in &tokens[..14] { for token in &tokens[..10] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0); if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 0, "{events:?}");
}else{
panic!("Expected chat events");
}
} }
// No tool output // No tool output
let mut output = String::new(); let events = chat_state.push(tokens[10].clone());
for token in &tokens[14..14 + 7] { if let ChatEvent::NoTool = events{
let events = chat_state.push(token.clone()); assert!(true);
assert_eq!(events.len(), 1); }else{
let content = get_text_content(&events[0]); panic!("Expected chat events");
output.push_str(content);
}
assert_eq!(output, "I am a helpful assistant!");
// No tool finish
for token in &tokens[14 + 7..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
}
#[test]
fn test_chat_stream_tool_no_tool_many_quotes() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"I".to_string(), // Event 1
" am".to_string(), // Event 2
" a".to_string(), // Event 3
" helpful".to_string(), // Event 4
" assistant".to_string(), // Event 5
"!\\\"\"".to_string(), // Extra inside the string quote that would get removed
"}".to_string(),
"}".to_string(),
];
// Initial ignored output
for text in &tokens[..14] {
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for text in &tokens[14..14 + 7] {
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
assert_eq!(events.len(), 1);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
..
} = &choices[0]
{
output.push_str(content);
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
}
assert_eq!(output, "I am a helpful assistant!\"");
// No tool finish
for text in &tokens[14 + 7..] {
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
assert_eq!(events.len(), 0);
}
}
#[test]
fn test_chat_stream_tool_no_tool_inline_json() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"{\\\"".to_string(), // Event 1
"a".to_string(), // Event 1
"\\\":".to_string(), // Event 1
"2".to_string(), // Event 2
",\\".to_string(), // Event 2
"\"".to_string(), // Event 2
"b".to_string(), // Event 3
"\\\": ".to_string(), // Event 4
"1".to_string(), // Event 5
"}".to_string(), // Event 5
"\"}".to_string(), // Extra inside the string quote that would get removed
"}".to_string(),
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..14] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for token in &tokens[14..14 + 12] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1, "Current text is {output:?}");
let content = get_text_content(&events[0]);
output.push_str(content);
}
assert_eq!(output, "{\"a\":2,\"b\": 1}");
// No tool finish
for token in &tokens[14 + 12..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "Extra events {events:?}");
} }
} }
@ -859,26 +577,21 @@ mod tests {
.collect(); .collect();
// Initial ignored output // Initial ignored output
for token in &tokens[..13] { for token in &tokens[..10] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0); if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 0, "{events:?}");
}else{
panic!("Expected chat events");
}
} }
// No tool output // No tool output
let mut output = String::new(); let events = chat_state.push(tokens[10].clone());
for token in &tokens[13..13 + 2] { if let ChatEvent::NoTool = events{
let events = chat_state.push(token.clone()); assert!(true);
assert_eq!(events.len(), 1, "Current text is {output:?}"); }else{
let content = get_text_content(&events[0]); panic!("Expected chat events");
output.push_str(content);
}
assert_eq!(output, "");
// No tool finish
for token in &tokens[13 + 2..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "Extra events {events:?}");
} }
} }
@ -946,7 +659,11 @@ mod tests {
// Initial ignored output // Initial ignored output
for token in &tokens[..11] { for token in &tokens[..11] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
}else{
panic!("Expected chat events");
}
} }
// No tool output // No tool output
@ -954,6 +671,7 @@ mod tests {
let mut output_name = String::new(); let mut output_name = String::new();
for token in &tokens[11..11 + 17] { for token in &tokens[11..11 + 17] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
let (name, arguments) = get_tool_call_content(&events[0]); let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name { if let Some(name) = name {
@ -961,6 +679,9 @@ mod tests {
output_name.push_str(&name); output_name.push_str(&name);
} }
output.push_str(arguments); output.push_str(arguments);
}else{
panic!("Expected chat events");
}
} }
assert_eq!(output_name, "get_current_weather"); assert_eq!(output_name, "get_current_weather");
@ -972,7 +693,11 @@ mod tests {
// No tool finish // No tool finish
for token in &tokens[11 + 17..] { for token in &tokens[11 + 17..] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0); if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 0, "{events:?}");
}else{
panic!("Expected chat events");
}
} }
} }
} }

View File

@ -1,4 +1,4 @@
use crate::chat::ChatState; use crate::chat::{ChatState, ChatEvent};
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
@ -1151,7 +1151,7 @@ pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(chat): Json<ChatRequest>, Json(mut chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1); metrics::counter!("tgi_request_count").increment(1);
@ -1166,7 +1166,7 @@ pub(crate) async fn chat_completions(
tracing::debug!("Got chat_template {:?}", infer.chat_template); tracing::debug!("Got chat_template {:?}", infer.chat_template);
let id = chat.next_tool_call_id(); let id = chat.next_tool_call_id();
let (generate_request, using_tools): (GenerateRequest, bool) = let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?; chat.clone().try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters)); span.record("parameters", format!("{:?}", generate_request.parameters));
let logprobs = logprobs.unwrap_or_default(); let logprobs = logprobs.unwrap_or_default();
@ -1179,15 +1179,28 @@ pub(crate) async fn chat_completions(
// switch on stream // switch on stream
if stream { if stream {
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id); let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
match result{ match result{
Ok(stream_token) => { Ok(stream_token) => {
let events = state.push(stream_token); let events = state.push(stream_token);
match events{
ChatEvent::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap();
assert_eq!(using_tools, false);
let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone());
response_stream = Box::pin(response_stream2);
}
ChatEvent::Events(events) => {
for chat_complete in events{ for chat_complete in events{
yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| { yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {
tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e); tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e);
@ -1195,6 +1208,8 @@ pub(crate) async fn chat_completions(
})); }));
} }
} }
}
}
Err(err) => yield Ok(err.into_openai_event()) Err(err) => yield Ok(err.into_openai_event())
} }
} }