Removing a lot of NO_TOOL shenanigans.

This commit is contained in:
Nicolas Patry 2025-03-10 23:45:00 +01:00
parent cb92acf280
commit 03fe626a95
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
3 changed files with 112 additions and 86 deletions

View File

@ -19,14 +19,18 @@ struct Call {
}
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatEvent{
pub(crate) enum ChatEvent {
NoTool,
Events(Vec<CompletionType>)
Events(Vec<CompletionType>),
}
pub(crate) fn parse_output(
generated_text: &str,
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatChoice {
NoTool,
ToolCalls(Vec<crate::ToolCall>),
}
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
@ -38,16 +42,7 @@ pub(crate) fn parse_output(
match &name[..] {
"no_tool" => {
// parse the content message
let content_message = call
.function
.arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError("No `content` found in generated text".to_string())
})?
.to_string();
Ok((None, Some(content_message)))
Ok(ChatChoice::NoTool)
}
name => {
let tool_calls = vec![crate::ToolCall {
@ -63,7 +58,7 @@ pub(crate) fn parse_output(
})?,
},
}];
Ok((Some(tool_calls), None))
Ok(ChatChoice::ToolCalls(tool_calls))
}
}
}
@ -194,8 +189,10 @@ impl ChatState {
match self.state {
StreamState::Buffering => {
self.text.push_str(token_text);
tracing::info!("Current text {:?}", self.text);
let partial = &self.text;
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
let partial =
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// This can be no_tool before the content has been emitted
if call.function._name != "no_tool" {
@ -212,7 +209,7 @@ impl ChatState {
events.push(chat_complete);
self.state = StreamState::Tool;
}else{
} else {
return ChatEvent::NoTool;
}
}
@ -362,7 +359,7 @@ mod tests {
index: 0,
details: None,
});
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
@ -382,7 +379,7 @@ mod tests {
}
_ => panic!("Unexpected chunk"),
}
}else{
} else {
panic!("Expected chat events");
}
}
@ -417,43 +414,43 @@ mod tests {
finish_reason: FinishReason::Length,
}),
});
if let ChatEvent::Events(events) = events{
assert_eq!(events.len(), 2);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(
choices,
&[ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "Hi".to_string(),
tool_call_id: None,
}),
logprobs: None,
// HAS A FINISH REASON
finish_reason: Some("length".to_string()),
}]
);
}
_ => panic!("Unexpected chunk"),
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 2);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(
choices,
&[ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "Hi".to_string(),
tool_call_id: None,
}),
logprobs: None,
// HAS A FINISH REASON
finish_reason: Some("length".to_string()),
}]
);
}
match &events[1] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => {
assert_eq!(
*usage,
Some(Usage {
prompt_tokens: 2,
completion_tokens: 10,
total_tokens: 12,
})
);
}
_ => panic!("Unexpected chunk"),
}
}else{
panic!("Expected chat events");
_ => panic!("Unexpected chunk"),
}
match &events[1] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => {
assert_eq!(
*usage,
Some(Usage {
prompt_tokens: 2,
completion_tokens: 10,
total_tokens: 12,
})
);
}
_ => panic!("Unexpected chunk"),
}
} else {
panic!("Expected chat events");
}
}
#[test]
@ -513,18 +510,18 @@ mod tests {
// Initial ignored output
for token in &tokens[..10] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
}else{
} else {
panic!("Expected chat events");
}
}
// No tool output
let events = chat_state.push(tokens[10].clone());
if let ChatEvent::NoTool = events{
if let ChatEvent::NoTool = events {
assert!(true);
}else{
} else {
panic!("Expected chat events");
}
}
@ -579,18 +576,18 @@ mod tests {
// Initial ignored output
for token in &tokens[..10] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
}else{
} else {
panic!("Expected chat events");
}
}
// No tool output
let events = chat_state.push(tokens[10].clone());
if let ChatEvent::NoTool = events{
if let ChatEvent::NoTool = events {
assert!(true);
}else{
} else {
panic!("Expected chat events");
}
}
@ -659,9 +656,9 @@ mod tests {
// Initial ignored output
for token in &tokens[..11] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
}else{
} else {
panic!("Expected chat events");
}
}
@ -671,7 +668,7 @@ mod tests {
let mut output_name = String::new();
for token in &tokens[11..11 + 17] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1);
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
@ -679,7 +676,7 @@ mod tests {
output_name.push_str(&name);
}
output.push_str(arguments);
}else{
} else {
panic!("Expected chat events");
}
}
@ -693,9 +690,9 @@ mod tests {
// No tool finish
for token in &tokens[11 + 17..] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events{
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
}else{
} else {
panic!("Expected chat events");
}
}

View File

@ -40,13 +40,13 @@ impl ToolGrammar {
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
// "properties": {
// "content": {
// "type": "string",
// "description": "The response content",
// }
// },
// "required": ["content"]
}),
},
}))

View File

@ -1,4 +1,4 @@
use crate::chat::{ChatState, ChatEvent};
use crate::chat::{ChatChoice, ChatEvent, ChatState};
/// HTTP Server logic
use crate::config::Config;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
@ -1178,8 +1178,13 @@ pub(crate) async fn chat_completions(
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
let (headers, response_stream) = generate_stream_internal(
infer.clone(),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
@ -1192,12 +1197,12 @@ pub(crate) async fn chat_completions(
ChatEvent::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap();
assert_eq!(using_tools, false);
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap();
assert!(!using_tools);
let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone());
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
response_stream = Box::pin(response_stream2);
}
ChatEvent::Events(events) => {
@ -1219,8 +1224,13 @@ pub(crate) async fn chat_completions(
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
let (mut headers, mut input_length, Json(generation)) = generate_internal(
Extension(infer.clone()),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -1228,7 +1238,26 @@ pub(crate) async fn chat_completions(
.as_secs();
let (tool_calls, output) = if using_tools {
crate::chat::parse_output(&generation.generated_text)?
match crate::chat::parse_output(&generation.generated_text)? {
ChatChoice::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
assert!(!using_tools);
let (headers_final, input_length_final, Json(generation)) = generate_internal(
Extension(infer),
compute_type,
Json(generate_request),
span,
)
.await?;
headers = headers_final;
input_length = input_length_final;
(None, Some(generation.generated_text))
}
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
}
} else {
(None, Some(generation.generated_text))
};