From f709466767eb51308d3dba4426c3956961b4bde2 Mon Sep 17 00:00:00 2001 From: datta0 Date: Fri, 24 Jan 2025 09:09:40 +0000 Subject: [PATCH] Make tool_call a list for streaming case --- router/src/lib.rs | 10 +++++----- router/src/server.rs | 25 ++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 414d38ed..19fa1579 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -727,7 +727,7 @@ pub(crate) struct ChatCompletionChoice { pub struct ToolCallDelta { #[schema(example = "assistant")] role: String, - tool_calls: DeltaToolCall, + tool_calls: Vec, // Changed to Vec } #[derive(Clone, Debug, Serialize, ToSchema)] @@ -770,15 +770,15 @@ impl ChatCompletionChunk { }), (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { role: "assistant".to_string(), - tool_calls: DeltaToolCall { - index: 0, + tool_calls: tool_calls.iter().enumerate().map(|(index, tool_call)| DeltaToolCall { + index: index as u32, id: String::new(), r#type: "function".to_string(), function: Function { name: None, - arguments: tool_calls[0].to_string(), + arguments: tool_call.to_string(), }, - }, + }).collect(), }), (None, None) => ChatCompletionDelta::Chat(TextMessage { role: "assistant".to_string(), diff --git a/router/src/server.rs b/router/src/server.rs index 29ceb8e6..9d2a509b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1138,7 +1138,27 @@ fn create_event_from_stream_token( // replace the content with the tool calls if grammar is present let (content, tool_calls) = if inner_using_tools { - (None, Some(vec![stream_token.token.text.clone()])) + // Create a DeltaToolCall object + let delta_tool_call = DeltaToolCall { + index: 0, // Assuming this is the first tool call + id: "0".to_string(), // Generate a unique ID here + r#type: "function".to_string(), + function: Function { + name: Some(stream_token.token.text.clone()), // Wrap in Some + arguments: stream_token.token.text.clone(), // Assuming the arguments are in the token text + }, + }; + + // Serialize the DeltaToolCall into a JSON string + let tool_call_string = serde_json::to_string(&delta_tool_call).unwrap_or_else(|e| { + println!("Failed to serialize DeltaToolCall: {:?}", e); + String::new() + }); + + // Wrap the serialized tool call in a Vec and Option + let tool_calls = Some(vec![tool_call_string]); + + (None, tool_calls) } else { let content = if !stream_token.token.special { Some(stream_token.token.text.clone()) @@ -1438,8 +1458,7 @@ pub(crate) async fn chat_completions( e )) })?; - println!("Arguments: {:?}", arguments); - println!("Arguments String: {:?}", arguments_string); + let tool_calls = vec![ToolCall { id: format!("{:09}", 0), r#type: "function".to_string(),