From 538456ba687fd6853abd874768014976651fb2ea Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 18 Feb 2025 21:13:03 +0000 Subject: [PATCH] fix: only send function name on first stream event --- .../test_openai_llama_tools.json | 2 +- ..._sea_creatures_stream_function_object.json | 2 +- ...r_tools_sea_creatures_stream_required.json | 2 +- ...test_flash_llama_grammar_tools_stream.json | 2 +- .../models/test_openai_llama_tools.py | 4 +- integration-tests/models/test_tools_llama.py | 6 +-- router/src/server.rs | 42 +++++++++++++++++-- 7 files changed, 48 insertions(+), 12 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json b/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json index 764e9946..a9d575ee 100644 --- a/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json +++ b/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json @@ -20,7 +20,7 @@ "logprobs": null } ], - "created": 1739799458, + "created": 1739910558, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json index 7478c079..dc95942d 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json @@ -20,7 +20,7 @@ "logprobs": null } ], - "created": 1739797595, + "created": 1739910826, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json index 33b775e5..37687f05 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json @@ -21,7 +21,7 @@ "logprobs": null } ], - "created": 1739456930, + "created": 1739910816, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json index 211f6a42..f4b40bd4 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -21,7 +21,7 @@ "logprobs": null } ], - "created": 1739367874, + "created": 1739910803, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/test_openai_llama_tools.py b/integration-tests/models/test_openai_llama_tools.py index 051d4089..d5e109bf 100644 --- a/integration-tests/models/test_openai_llama_tools.py +++ b/integration-tests/models/test_openai_llama_tools.py @@ -101,7 +101,9 @@ async def test_openai_llama_tools(openai_llama_tools, response_snapshot): tool_call_string = "" for chunk in chat_completion: - tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments + function_call = chunk.choices[0].delta.tool_calls[0].function + if function_call: + tool_call_string += function_call.arguments last_chunk = chunk.to_dict() assert ( diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 68f6bfaa..b9376be7 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -216,7 +216,7 @@ async def test_flash_llama_grammar_tools_stream( assert response.choices[0].delta.content is None assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}' - assert count == 16 + assert count == 17 assert last_response == response_snapshot @@ -360,7 +360,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required( ) last_response = response - assert count == 23 + assert count == 24 assert ( tool_calls_generated == '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}' @@ -458,7 +458,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( tool_calls_generated += tool_call["function"]["arguments"] last_response = response - assert count == 25 + assert count == 26 assert ( tool_calls_generated == '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}' diff --git a/router/src/server.rs b/router/src/server.rs index f94046a8..1ca7cb7e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1305,6 +1305,40 @@ pub(crate) async fn chat_completions( state = StreamState::Content { skip_close_quote: false, }; + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta { + role: "assistant".to_string(), + tool_calls: vec![DeltaToolCall { + index: 0, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: Some(global_function_name.clone()), + arguments: "".to_string(), + }, + }], + }); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk{ + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![ChatCompletionChoice { + index: 0, + delta: tool_delta_start, + logprobs: None, + finish_reason: None, + }], + usage: None, + }); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); buffer.drain(1..); // only keep the first token (opening '{') buffer[0].token.text = buffer[0].token.text.chars().take(1).collect(); } @@ -1341,7 +1375,7 @@ pub(crate) async fn chat_completions( None, None, None, - Some(global_function_name.clone()), + None, )); yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { InferError::StreamSerializationError(e.to_string()).into() @@ -1370,7 +1404,7 @@ pub(crate) async fn chat_completions( response_as_tool, system_fingerprint.clone(), model_id.clone(), - Some(global_function_name.clone()), + None, ); yield Ok::(event); @@ -1394,7 +1428,7 @@ pub(crate) async fn chat_completions( response_as_tool, system_fingerprint.clone(), model_id.clone(), - Some(global_function_name.clone()), + None, ); yield Ok::(event); } else { @@ -1407,7 +1441,7 @@ pub(crate) async fn chat_completions( response_as_tool, system_fingerprint.clone(), model_id.clone(), - Some(global_function_name.clone()), + None, ); yield Ok::(event); }