Removing a lot of NO_TOOL shenanigans.

This commit is contained in:
Nicolas Patry 2025-03-10 23:45:00 +01:00
parent cb92acf280
commit 03fe626a95
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
3 changed files with 112 additions and 86 deletions

View File

@ -19,14 +19,18 @@ struct Call {
} }
#[cfg_attr(test, derive(Debug))] #[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatEvent{ pub(crate) enum ChatEvent {
NoTool, NoTool,
Events(Vec<CompletionType>) Events(Vec<CompletionType>),
} }
pub(crate) fn parse_output( #[cfg_attr(test, derive(Debug))]
generated_text: &str, pub(crate) enum ChatChoice {
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> { NoTool,
ToolCalls(Vec<crate::ToolCall>),
}
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
let call: Call = serde_json::from_str(generated_text).map_err(|e| { let call: Call = serde_json::from_str(generated_text).map_err(|e| {
InferError::ToolError(format!( InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}", "Failed to parse generated text: {} {:?}",
@ -38,16 +42,7 @@ pub(crate) fn parse_output(
match &name[..] { match &name[..] {
"no_tool" => { "no_tool" => {
// parse the content message // parse the content message
let content_message = call Ok(ChatChoice::NoTool)
.function
.arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError("No `content` found in generated text".to_string())
})?
.to_string();
Ok((None, Some(content_message)))
} }
name => { name => {
let tool_calls = vec![crate::ToolCall { let tool_calls = vec![crate::ToolCall {
@ -63,7 +58,7 @@ pub(crate) fn parse_output(
})?, })?,
}, },
}]; }];
Ok((Some(tool_calls), None)) Ok(ChatChoice::ToolCalls(tool_calls))
} }
} }
} }
@ -194,8 +189,10 @@ impl ChatState {
match self.state { match self.state {
StreamState::Buffering => { StreamState::Buffering => {
self.text.push_str(token_text); self.text.push_str(token_text);
tracing::info!("Current text {:?}", self.text);
let partial = &self.text; let partial = &self.text;
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ','); let partial =
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) { if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// This can be no_tool before the content has been emitted // This can be no_tool before the content has been emitted
if call.function._name != "no_tool" { if call.function._name != "no_tool" {
@ -212,7 +209,7 @@ impl ChatState {
events.push(chat_complete); events.push(chat_complete);
self.state = StreamState::Tool; self.state = StreamState::Tool;
}else{ } else {
return ChatEvent::NoTool; return ChatEvent::NoTool;
} }
} }
@ -362,7 +359,7 @@ mod tests {
index: 0, index: 0,
details: None, details: None,
}); });
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
match &events[0] { match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -382,7 +379,7 @@ mod tests {
} }
_ => panic!("Unexpected chunk"), _ => panic!("Unexpected chunk"),
} }
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -417,7 +414,7 @@ mod tests {
finish_reason: FinishReason::Length, finish_reason: FinishReason::Length,
}), }),
}); });
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 2); assert_eq!(events.len(), 2);
match &events[0] { match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -451,7 +448,7 @@ mod tests {
} }
_ => panic!("Unexpected chunk"), _ => panic!("Unexpected chunk"),
} }
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -513,18 +510,18 @@ mod tests {
// Initial ignored output // Initial ignored output
for token in &tokens[..10] { for token in &tokens[..10] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
// No tool output // No tool output
let events = chat_state.push(tokens[10].clone()); let events = chat_state.push(tokens[10].clone());
if let ChatEvent::NoTool = events{ if let ChatEvent::NoTool = events {
assert!(true); assert!(true);
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -579,18 +576,18 @@ mod tests {
// Initial ignored output // Initial ignored output
for token in &tokens[..10] { for token in &tokens[..10] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
// No tool output // No tool output
let events = chat_state.push(tokens[10].clone()); let events = chat_state.push(tokens[10].clone());
if let ChatEvent::NoTool = events{ if let ChatEvent::NoTool = events {
assert!(true); assert!(true);
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -659,9 +656,9 @@ mod tests {
// Initial ignored output // Initial ignored output
for token in &tokens[..11] { for token in &tokens[..11] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -671,7 +668,7 @@ mod tests {
let mut output_name = String::new(); let mut output_name = String::new();
for token in &tokens[11..11 + 17] { for token in &tokens[11..11 + 17] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
let (name, arguments) = get_tool_call_content(&events[0]); let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name { if let Some(name) = name {
@ -679,7 +676,7 @@ mod tests {
output_name.push_str(&name); output_name.push_str(&name);
} }
output.push_str(arguments); output.push_str(arguments);
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }
@ -693,9 +690,9 @@ mod tests {
// No tool finish // No tool finish
for token in &tokens[11 + 17..] { for token in &tokens[11 + 17..] {
let events = chat_state.push(token.clone()); let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{ if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
}else{ } else {
panic!("Expected chat events"); panic!("Expected chat events");
} }
} }

View File

@ -40,13 +40,13 @@ impl ToolGrammar {
), ),
arguments: json!({ arguments: json!({
"type": "object", "type": "object",
"properties": { // "properties": {
"content": { // "content": {
"type": "string", // "type": "string",
"description": "The response content", // "description": "The response content",
} // }
}, // },
"required": ["content"] // "required": ["content"]
}), }),
}, },
})) }))

View File

@ -1,4 +1,4 @@
use crate::chat::{ChatState, ChatEvent}; use crate::chat::{ChatChoice, ChatEvent, ChatState};
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
@ -1178,8 +1178,13 @@ pub(crate) async fn chat_completions(
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream // switch on stream
if stream { if stream {
let (headers, response_stream) = let (headers, response_stream) = generate_stream_internal(
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; infer.clone(),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await;
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
@ -1194,10 +1199,10 @@ pub(crate) async fn chat_completions(
chat.response_format = None; chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) = let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap(); chat.clone().try_into_generate(&infer).unwrap();
assert_eq!(using_tools, false); assert!(!using_tools);
let (_headers, response_stream2) = let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone()); state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
response_stream = Box::pin(response_stream2); response_stream = Box::pin(response_stream2);
} }
ChatEvent::Events(events) => { ChatEvent::Events(events) => {
@ -1219,8 +1224,13 @@ pub(crate) async fn chat_completions(
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, input_length, Json(generation)) = let (mut headers, mut input_length, Json(generation)) = generate_internal(
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?; Extension(infer.clone()),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -1228,7 +1238,26 @@ pub(crate) async fn chat_completions(
.as_secs(); .as_secs();
let (tool_calls, output) = if using_tools { let (tool_calls, output) = if using_tools {
crate::chat::parse_output(&generation.generated_text)? match crate::chat::parse_output(&generation.generated_text)? {
ChatChoice::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
assert!(!using_tools);
let (headers_final, input_length_final, Json(generation)) = generate_internal(
Extension(infer),
compute_type,
Json(generate_request),
span,
)
.await?;
headers = headers_final;
input_length = input_length_final;
(None, Some(generation.generated_text))
}
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
}
} else { } else {
(None, Some(generation.generated_text)) (None, Some(generation.generated_text))
}; };