Removing a lot of NO_TOOL shenanigans.

This commit is contained in:
Nicolas Patry 2025-03-10 23:45:00 +01:00
parent cb92acf280
commit 03fe626a95
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
3 changed files with 112 additions and 86 deletions

View File

@@ -21,12 +21,16 @@ struct Call {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatEvent {
NoTool,
Events(Vec<CompletionType>)
Events(Vec<CompletionType>),
}
pub(crate) fn parse_output(
generated_text: &str,
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> {
/// Outcome of parsing the generated text of a tool-enabled chat request.
///
/// Produced by `parse_output`: the model either invoked the special
/// `no_tool` escape function or emitted a concrete tool call.
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatChoice {
// The model called `no_tool` — the caller re-runs generation with
// `tools`/`response_format` cleared to obtain a plain text answer.
NoTool,
// The model produced a parsed tool call (wrapped in a Vec by the caller).
ToolCalls(Vec<crate::ToolCall>),
}
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
@@ -38,16 +42,7 @@ pub(crate) fn parse_output(
match &name[..] {
"no_tool" => {
// parse the content message
let content_message = call
.function
.arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError("No `content` found in generated text".to_string())
})?
.to_string();
Ok((None, Some(content_message)))
Ok(ChatChoice::NoTool)
}
name => {
let tool_calls = vec![crate::ToolCall {
@@ -63,7 +58,7 @@ pub(crate) fn parse_output(
})?,
},
}];
Ok((Some(tool_calls), None))
Ok(ChatChoice::ToolCalls(tool_calls))
}
}
}
@@ -194,8 +189,10 @@ impl ChatState {
match self.state {
StreamState::Buffering => {
self.text.push_str(token_text);
tracing::info!("Current text {:?}", self.text);
let partial = &self.text;
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
let partial =
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// This can be no_tool before the content has been emitted
if call.function._name != "no_tool" {

View File

@@ -40,13 +40,13 @@ impl ToolGrammar {
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
// "properties": {
// "content": {
// "type": "string",
// "description": "The response content",
// }
// },
// "required": ["content"]
}),
},
}))

View File

@@ -1,4 +1,4 @@
use crate::chat::{ChatState, ChatEvent};
use crate::chat::{ChatChoice, ChatEvent, ChatState};
/// HTTP Server logic
use crate::config::Config;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
@@ -1178,8 +1183,13 @@ pub(crate) async fn chat_completions(
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
let (headers, response_stream) = generate_stream_internal(
infer.clone(),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
@@ -1194,10 +1199,10 @@ pub(crate) async fn chat_completions(
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap();
assert_eq!(using_tools, false);
assert!(!using_tools);
let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone());
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
response_stream = Box::pin(response_stream2);
}
ChatEvent::Events(events) => {
@@ -1219,8 +1224,13 @@ pub(crate) async fn chat_completions(
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
let (mut headers, mut input_length, Json(generation)) = generate_internal(
Extension(infer.clone()),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@@ -1228,7 +1238,26 @@ pub(crate) async fn chat_completions(
.as_secs();
let (tool_calls, output) = if using_tools {
crate::chat::parse_output(&generation.generated_text)?
match crate::chat::parse_output(&generation.generated_text)? {
ChatChoice::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
assert!(!using_tools);
let (headers_final, input_length_final, Json(generation)) = generate_internal(
Extension(infer),
compute_type,
Json(generate_request),
span,
)
.await?;
headers = headers_final;
input_length = input_length_final;
(None, Some(generation.generated_text))
}
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
}
} else {
(None, Some(generation.generated_text))
};