diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 1085075e..6f51c153 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
 class ChoiceDelta(BaseModel):
     role: str
     content: Optional[str] = None
-    tool_calls: Optional[ChoiceDeltaToolCall] = None
+    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
 
 
 class Choice(BaseModel):
diff --git a/integration-tests/models/__snapshots__/test_openai_integration/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_openai_integration/test_flash_llama_grammar_tools.json
new file mode 100644
index 00000000..315969ad
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_openai_integration/test_flash_llama_grammar_tools.json
@@ -0,0 +1 @@
+"ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='', function=ChoiceDeltaToolCallFunction(arguments='\"}', name='get_current_weather'), type='function')]), finish_reason=None, index=0, logprobs=None)], created=1739798781, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion.chunk', service_tier=None, system_fingerprint='3.1.1-dev0-native', usage=None)"
diff --git a/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json b/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json
new file mode 100644
index 00000000..764e9946
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_openai_llama_tools/test_openai_llama_tools.json
@@ -0,0 +1,29 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "role": "assistant",
+        "tool_calls": [
+          {
+            "function": {
+              "arguments": "\"}",
+              "name": "get_current_weather"
+            },
+            "id": "",
+            "index": 0,
+            "type": "function"
+          }
+        ]
+      },
+      "finish_reason": null,
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1739799458,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "3.1.1-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json
index 3ed893fa..f858c7d8 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json
@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I am an AI assistant",
+        "content": "I am a helpful assistant!",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1728497062,
+  "created": 1739357385,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
-  "system_fingerprint": "2.4.2-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
     "completion_tokens": 23,
-    "prompt_tokens": 604,
-    "total_tokens": 627
+    "prompt_tokens": 494,
+    "total_tokens": 517
   }
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json
index b134004a..862f6cf3 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json
@@ -11,10 +11,10 @@
       "logprobs": null
     }
   ],
-  "created": 1728497531,
+  "created": 1739441937,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.4.2-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json
index 1362b472..b8da13ab 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json
@@ -2,7 +2,7 @@
   "choices": [
     {
       "delta": {
-        "content": " fans",
+        "content": " Oracle",
         "role": "assistant",
         "tool_calls": null
       },
@@ -11,10 +11,10 @@
       "logprobs": null
     }
   ],
-  "created": 1728497461,
+  "created": 1739444803,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.4.2-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
index bb8d61c8..7478c079 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
@@ -3,25 +3,27 @@
     {
       "delta": {
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "arguments": "<|eot_id|>",
-            "name": null
-          },
-          "id": "",
-          "index": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "arguments": "}",
+              "name": "get_n_day_weather_forecast"
+            },
+            "id": "",
+            "index": 0,
+            "type": "function"
+          }
+        ]
       },
-      "finish_reason": "stop",
+      "finish_reason": null,
       "index": 0,
       "logprobs": null
     }
   ],
-  "created": 1732293254,
+  "created": 1739797595,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.4.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json
index 2ccab4a9..b77b1022 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json
@@ -11,10 +11,10 @@
       "logprobs": null
     }
   ],
-  "created": 1729262528,
+  "created": 1739454835,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.3.2-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
index dbced5b8..33b775e5 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
@@ -4,25 +4,27 @@
       "delta": {
         "content": null,
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "arguments": "<|eot_id|>",
-            "name": null
-          },
-          "id": "",
-          "index": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "arguments": "}",
+              "name": "get_n_day_weather_forecast"
+            },
+            "id": "",
+            "index": 0,
+            "type": "function"
+          }
+        ]
       },
-      "finish_reason": "stop",
+      "finish_reason": null,
       "index": 0,
       "logprobs": null
     }
   ],
-  "created": 1732293246,
+  "created": 1739456930,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.4.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
index 27d2f9ca..211f6a42 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
@@ -4,25 +4,27 @@
       "delta": {
         "content": null,
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "arguments": "<|eot_id|>",
-            "name": null
-          },
-          "id": "",
-          "index": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "arguments": "\"}",
+              "name": "get_current_weather"
+            },
+            "id": "",
+            "index": 0,
+            "type": "function"
+          }
+        ]
       },
-      "finish_reason": "stop",
+      "finish_reason": null,
       "index": 0,
       "logprobs": null
     }
   ],
-  "created": 1732293235,
+  "created": 1739367874,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.4.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 }
diff --git a/integration-tests/models/test_openai_llama_tools.py b/integration-tests/models/test_openai_llama_tools.py
new file mode 100644
index 00000000..051d4089
--- /dev/null
+++ b/integration-tests/models/test_openai_llama_tools.py
@@ -0,0 +1,110 @@
+from openai import OpenAI
+import pytest
+
+
+@pytest.fixture(scope="module")
+def openai_llama_tools_handle(launcher):
+    with launcher(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        num_shard=2,
+        disable_grammar_support=False,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def openai_llama_tools(openai_llama_tools_handle):
+    await openai_llama_tools_handle.health(300)
+    return openai_llama_tools_handle.client
+
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "format": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description": "The temperature unit to use. Infer this from the users location.",
+                    },
+                },
+                "required": ["location", "format"],
+                "additionalProperties": False,
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "get_n_day_weather_forecast",
+            "description": "Get an N-day weather forecast",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "format": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description": "The temperature unit to use. Infer this from the users location.",
+                    },
+                    "num_days": {
+                        "type": "integer",
+                        "description": "The number of days to forecast",
+                    },
+                },
+                "required": ["location", "format", "num_days"],
+                "additionalProperties": False,
+            },
+        },
+    },
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
+    client = OpenAI(
+        base_url=f"{openai_llama_tools.base_url}/v1",
+        api_key="_",
+    )
+
+    chat_completion = client.chat.completions.create(
+        model="tgi",
+        messages=[
+            {
+                "role": "system",
+                "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
+            },
+            {
+                "role": "user",
+                "content": "What's the weather like the next 3 days in San Francisco, CA?",
+            },
+        ],
+        tools=tools,
+        tool_choice="get_current_weather",
+        max_tokens=500,
+        stream=True,
+        seed=42,
+    )
+
+    tool_call_string = ""
+    for chunk in chat_completion:
+        tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
+        last_chunk = chunk.to_dict()
+
+    assert (
+        tool_call_string == '{ "location": "San Francisco, CA", "format": "fahrenheit"}'
+    )
+    assert last_chunk == response_snapshot
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index b8a90cff..68f6bfaa 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -5,11 +5,7 @@ import json
 
 @pytest.fixture(scope="module")
 def flash_llama_grammar_tools_handle(launcher):
-    with launcher(
-        "meta-llama/Meta-Llama-3.1-8B-Instruct",
-        num_shard=2,
-        disable_grammar_support=False,
-    ) as handle:
+    with launcher("meta-llama/Meta-Llama-3.1-8B-Instruct") as handle:
         yield handle
 
 
@@ -101,7 +97,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
+                "arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
             },
         }
     ]
@@ -138,7 +134,7 @@ async def test_flash_llama_grammar_tools_auto(
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
+                "arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
             },
         }
     ]
@@ -176,7 +172,7 @@ async def test_flash_llama_grammar_tools_choice(
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
+                "arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
             },
         }
     ]
@@ -213,15 +209,14 @@ async def test_flash_llama_grammar_tools_stream(
     last_response = None
     async for response in responses:
         count += 1
-        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        tool_calls_generated += (
+            response.choices[0].delta.tool_calls[0].function.arguments
+        )
         last_response = response
         assert response.choices[0].delta.content is None
 
-    assert (
-        tool_calls_generated
-        == '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
-    )
-    assert count == 28
+    assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
+    assert count == 16
     assert last_response == response_snapshot
 
 
@@ -249,7 +244,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )
 
     assert responses.choices[0].message.tool_calls is None
-    assert responses.choices[0].message.content == "I am an AI assistant"
+    assert responses.choices[0].message.content == "I am a helpful assistant!"
 
     assert responses == response_snapshot
 
@@ -287,7 +282,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
         assert response.choices[0].delta.tool_calls is None
 
     assert count == 5
-    assert content_generated == "I am an AI assistant"
+    assert content_generated == "I am a helpful assistant"
     assert last_response == response_snapshot
 
 
@@ -323,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
         last_response = response
         assert response.choices[0].delta.tool_calls is None
 
-    assert count == 62
+    assert count == 77
     assert (
         content_generated
-        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
+        == "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
     )
     assert last_response == response_snapshot
 
@@ -360,13 +355,15 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
     async for response in responses:
         count += 1
         assert response.choices[0].delta.content is None
-        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        tool_calls_generated += (
+            response.choices[0].delta.tool_calls[0].function.arguments
+        )
         last_response = response
 
-    assert count == 29
+    assert count == 23
     assert (
         tool_calls_generated
-        == '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
+        == '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
     )
     assert last_response == response_snapshot
 
@@ -457,15 +454,14 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
             if line == "[DONE]":
                 break
             response = json.loads(line)
-            tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
-                "function"
-            ]["arguments"]
+            tool_call = response["choices"][0]["delta"]["tool_calls"][0]
+            tool_calls_generated += tool_call["function"]["arguments"]
             last_response = response
 
-    assert count == 39
+    assert count == 25
     assert (
         tool_calls_generated
-        == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
+        == '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
     )
     assert last_response == response_snapshot
diff --git a/router/src/server.rs b/router/src/server.rs
index e96d9bf5..f94046a8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1142,9 +1142,7 @@ fn create_event_from_stream_token(
 
     // replace the content with the tool calls if grammar is present
     let (content, tool_calls) = if inner_using_tools {
-        // escape the token text so its a json string
-        let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
-        (None, Some(vec![escaped_text]))
+        (None, Some(vec![stream_token.token.text.clone()]))
     } else {
         let content = if !stream_token.token.special {
             Some(stream_token.token.text.clone())
@@ -1307,7 +1305,8 @@ pub(crate) async fn chat_completions(
                                        state = StreamState::Content {
                                            skip_close_quote: false,
                                        };
-                                       buffer = buffer.drain(0..1).collect();
+                                       buffer.drain(1..); // only keep the first token (opening '{')
+                                       buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
                                     }
                                 }
                             }
@@ -1361,40 +1360,59 @@ pub(crate) async fn chat_completions(
                                 }
                                 buffer.push(stream_token);
-                                // FIFO send the buffer but left the last two elements (closing '}' and EOS token)
-                                for stream_token in &buffer[..buffer.len() - 2] {
-                                    let event = create_event_from_stream_token(
-                                        stream_token,
-                                        logprobs,
-                                        stream_options.clone(),
-                                        response_as_tool,
-                                        system_fingerprint.clone(),
-                                        model_id.clone(),
-                                        Some(global_function_name.clone()),
-                                    );
+                                if buffer.len() > 1 {
+                                    // FIFO send the buffer but left the last two elements (closing '}' and EOS token)
+                                    for stream_token in &buffer[..buffer.len() - 2] {
+                                        let event = create_event_from_stream_token(
+                                            stream_token,
+                                            logprobs,
+                                            stream_options.clone(),
+                                            response_as_tool,
+                                            system_fingerprint.clone(),
+                                            model_id.clone(),
+                                            Some(global_function_name.clone()),
+                                        );
 
-                                    yield Ok::<Event, Infallible>(event);
+                                        yield Ok::<Event, Infallible>(event);
+                                    }
+                                    buffer = buffer.drain(buffer.len() - 2..).collect();
                                 }
-                                buffer = buffer.drain(buffer.len() - 2..).collect();
                             }
                         }
                     }
                     Err(err) => yield Ok(err.into_openai_event())
                 }
             }
-            // send the second to last stream token but remove the trailing '}' if it exists
-            let mut closing_stream_token = buffer.remove(0);
-            closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
-            let event = create_event_from_stream_token(
-                &closing_stream_token,
-                logprobs,
-                stream_options.clone(),
-                response_as_tool,
-                system_fingerprint.clone(),
-                model_id.clone(),
-                Some(global_function_name.clone()),
-            );
-            yield Ok::<Event, Infallible>(event);
+            if response_as_tool {
+                // send the second to last stream token but remove the trailing '}' if it exists
+                let mut closing_stream_token = buffer.remove(0);
+                closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
+                let event = create_event_from_stream_token(
+                    &closing_stream_token,
+                    logprobs,
+                    stream_options.clone(),
+                    response_as_tool,
+                    system_fingerprint.clone(),
+                    model_id.clone(),
+                    Some(global_function_name.clone()),
+                );
+                yield Ok::<Event, Infallible>(event);
+            } else {
+                // send each buffer element
+                for stream_token in buffer {
+                    let event = create_event_from_stream_token(
+                        &stream_token,
+                        logprobs,
+                        stream_options.clone(),
+                        response_as_tool,
+                        system_fingerprint.clone(),
+                        model_id.clone(),
+                        Some(global_function_name.clone()),
+                    );
+                    yield Ok::<Event, Infallible>(event);
+                }
+            }
+            yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };
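For readers consuming the stream from a client, the patch changes two observable things: `ChoiceDelta.tool_calls` is now a `List[ChoiceDeltaToolCall]` rather than a single object (matching the OpenAI wire format), and the router terminates the SSE stream with an explicit `[DONE]` event. The sketch below shows consumption of the new list-shaped deltas. It is an illustration, not part of the patch: it assumes a TGI server already listening on `http://localhost:8080`, and the `tools` schema is a trimmed copy of the one in the tests above (the `model` name is ignored by TGI; `"tgi"` is just a placeholder).

```python
# Minimal sketch: accumulate streamed tool-call arguments with the new
# list-shaped `delta.tool_calls`. Assumes a local TGI server is running.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What's the weather in Paris, France?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)

arguments = ""
for chunk in stream:
    delta = chunk.choices[0].delta
    # `tool_calls` is now a list of ChoiceDeltaToolCall deltas, not a single object
    if delta.tool_calls:
        arguments += delta.tool_calls[0].function.arguments or ""

print(arguments)  # e.g. '{ "location": "Paris, France", "format": "celsius"}'
```

With the pre-patch client types, `delta.tool_calls.function.arguments` (no index) would have been required instead, which is the access pattern the updated tests no longer use.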