diff --git a/router/src/chat.rs b/router/src/chat.rs index f1fd1948..d5824fea 100644 --- a/router/src/chat.rs +++ b/router/src/chat.rs @@ -19,14 +19,18 @@ struct Call { } #[cfg_attr(test, derive(Debug))] -pub(crate) enum ChatEvent{ +pub(crate) enum ChatEvent { NoTool, - Events(Vec) + Events(Vec), } -pub(crate) fn parse_output( - generated_text: &str, -) -> Result<(Option>, Option), InferError> { +#[cfg_attr(test, derive(Debug))] +pub(crate) enum ChatChoice { + NoTool, + ToolCalls(Vec), +} + +pub(crate) fn parse_output(generated_text: &str) -> Result { let call: Call = serde_json::from_str(generated_text).map_err(|e| { InferError::ToolError(format!( "Failed to parse generated text: {} {:?}", @@ -38,16 +42,7 @@ pub(crate) fn parse_output( match &name[..] { "no_tool" => { // parse the content message - let content_message = call - .function - .arguments - .get("content") - .and_then(Value::as_str) - .ok_or_else(|| { - InferError::ToolError("No `content` found in generated text".to_string()) - })? - .to_string(); - Ok((None, Some(content_message))) + Ok(ChatChoice::NoTool) } name => { let tool_calls = vec![crate::ToolCall { @@ -63,7 +58,7 @@ pub(crate) fn parse_output( })?, }, }]; - Ok((Some(tool_calls), None)) + Ok(ChatChoice::ToolCalls(tool_calls)) } } } @@ -194,8 +189,10 @@ impl ChatState { match self.state { StreamState::Buffering => { self.text.push_str(token_text); + tracing::info!("Current text {:?}", self.text); let partial = &self.text; - let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ','); + let partial = + partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}'); if let Ok(call) = serde_json::from_str::(&format!("{}}}}}", partial)) { // This can be no_tool before the content has been emitted if call.function._name != "no_tool" { @@ -212,7 +209,7 @@ impl ChatState { events.push(chat_complete); self.state = StreamState::Tool; - }else{ + } else { return ChatEvent::NoTool; } } @@ -362,7 +359,7 @@ mod tests { index: 0, details: None, }); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 1); match &events[0] { CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { @@ -382,7 +379,7 @@ mod tests { } _ => panic!("Unexpected chunk"), } - }else{ + } else { panic!("Expected chat events"); } } @@ -417,43 +414,43 @@ mod tests { finish_reason: FinishReason::Length, }), }); - if let ChatEvent::Events(events) = events{ - assert_eq!(events.len(), 2); - match &events[0] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { - assert_eq!( - choices, - &[ChatCompletionChoice { - index: 0, - delta: ChatCompletionDelta::Chat(TextMessage { - role: "assistant".to_string(), - content: "Hi".to_string(), - tool_call_id: None, - }), - logprobs: None, - // HAS A FINISH REASON - finish_reason: Some("length".to_string()), - }] - ); - } - _ => panic!("Unexpected chunk"), + if let ChatEvent::Events(events) = events { + assert_eq!(events.len(), 2); + match &events[0] { + CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { + assert_eq!( + choices, + &[ChatCompletionChoice { + index: 0, + delta: ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: "Hi".to_string(), + tool_call_id: None, + }), + logprobs: None, + // HAS A FINISH REASON + finish_reason: Some("length".to_string()), + }] + ); } - match &events[1] { - CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => { - assert_eq!( - *usage, - Some(Usage { - prompt_tokens: 2, - completion_tokens: 10, - total_tokens: 12, - }) - ); - } - _ => panic!("Unexpected chunk"), - } - }else{ - panic!("Expected chat events"); + _ => panic!("Unexpected chunk"), } + match &events[1] { + CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => { + assert_eq!( + *usage, + Some(Usage { + prompt_tokens: 2, + completion_tokens: 10, + total_tokens: 12, + }) + ); + } + _ => panic!("Unexpected chunk"), + } + } else { + panic!("Expected chat events"); + } } #[test] @@ -513,18 +510,18 @@ mod tests { // Initial ignored output for token in &tokens[..10] { let events = chat_state.push(token.clone()); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 0, "{events:?}"); - }else{ + } else { panic!("Expected chat events"); } } // No tool output let events = chat_state.push(tokens[10].clone()); - if let ChatEvent::NoTool = events{ + if let ChatEvent::NoTool = events { assert!(true); - }else{ + } else { panic!("Expected chat events"); } } @@ -579,18 +576,18 @@ mod tests { // Initial ignored output for token in &tokens[..10] { let events = chat_state.push(token.clone()); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 0, "{events:?}"); - }else{ + } else { panic!("Expected chat events"); } } // No tool output let events = chat_state.push(tokens[10].clone()); - if let ChatEvent::NoTool = events{ + if let ChatEvent::NoTool = events { assert!(true); - }else{ + } else { panic!("Expected chat events"); } } @@ -659,9 +656,9 @@ mod tests { // Initial ignored output for token in &tokens[..11] { let events = chat_state.push(token.clone()); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 0, "{events:?}"); - }else{ + } else { panic!("Expected chat events"); } } @@ -671,7 +668,7 @@ mod tests { let mut output_name = String::new(); for token in &tokens[11..11 + 17] { let events = chat_state.push(token.clone()); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 1); let (name, arguments) = get_tool_call_content(&events[0]); if let Some(name) = name { @@ -679,7 +676,7 @@ mod tests { output_name.push_str(&name); } output.push_str(arguments); - }else{ + } else { panic!("Expected chat events"); } } @@ -693,9 +690,9 @@ mod tests { // No tool finish for token in &tokens[11 + 17..] { let events = chat_state.push(token.clone()); - if let ChatEvent::Events(events) = events{ + if let ChatEvent::Events(events) = events { assert_eq!(events.len(), 0, "{events:?}"); - }else{ + } else { panic!("Expected chat events"); } } diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 7770cd9d..e4e20859 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -40,13 +40,13 @@ impl ToolGrammar { ), arguments: json!({ "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] + // "properties": { + // "content": { + // "type": "string", + // "description": "The response content", + // } + // }, + // "required": ["content"] }), }, })) diff --git a/router/src/server.rs b/router/src/server.rs index 2e29fb80..7c6c7e01 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,4 +1,4 @@ -use crate::chat::{ChatState, ChatEvent}; +use crate::chat::{ChatChoice, ChatEvent, ChatState}; /// HTTP Server logic use crate::config::Config; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; @@ -1178,8 +1178,13 @@ pub(crate) async fn chat_completions( let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); // switch on stream if stream { - let (headers, response_stream) = - generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; + let (headers, response_stream) = generate_stream_internal( + infer.clone(), + compute_type.clone(), + Json(generate_request), + span.clone(), + ) + .await; let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); @@ -1192,12 +1197,12 @@ pub(crate) async fn chat_completions( ChatEvent::NoTool => { chat.tools = None; chat.response_format = None; - let (generate_request, using_tools): (GenerateRequest, bool) = - chat.clone().try_into_generate(&infer).unwrap(); - assert_eq!(using_tools, false); + let (generate_request, using_tools): (GenerateRequest, bool) = + chat.clone().try_into_generate(&infer).unwrap(); + assert!(!using_tools); let (_headers, response_stream2) = generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; - state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone()); + state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); response_stream = Box::pin(response_stream2); } ChatEvent::Events(events) => { @@ -1219,8 +1224,13 @@ pub(crate) async fn chat_completions( let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, input_length, Json(generation)) = - generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?; + let (mut headers, mut input_length, Json(generation)) = generate_internal( + Extension(infer.clone()), + compute_type.clone(), + Json(generate_request), + span.clone(), + ) + .await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -1228,7 +1238,26 @@ pub(crate) async fn chat_completions( .as_secs(); let (tool_calls, output) = if using_tools { - crate::chat::parse_output(&generation.generated_text)? + match crate::chat::parse_output(&generation.generated_text)? { + ChatChoice::NoTool => { + chat.tools = None; + chat.response_format = None; + let (generate_request, using_tools): (GenerateRequest, bool) = + chat.clone().try_into_generate(&infer)?; + assert!(!using_tools); + let (headers_final, input_length_final, Json(generation)) = generate_internal( + Extension(infer), + compute_type, + Json(generate_request), + span, + ) + .await?; + headers = headers_final; + input_length = input_length_final; + (None, Some(generation.generated_text)) + } + ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None), + } } else { (None, Some(generation.generated_text)) };