fix: adjust stream, improve tests and add openai client test

drbh 2025-02-17 13:38:49 +00:00
parent 07c20903e5
commit 40f905d00b
13 changed files with 265 additions and 105 deletions

View File

@@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall] = None
tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class Choice(BaseModel):
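
Note: the change above widens `tool_calls` from a single object to a list, matching the OpenAI wire format where one streamed chunk may carry several tool-call deltas addressed by `index`. A minimal sketch of what now validates — the class names mirror this file's models (including the `ChoiceDeltaToolCallFunction` visible in the snapshot below), but this is an illustrative reconstruction, not the exact code:

# Hedged sketch: field names follow this file's models; not the exact classes.
from typing import List, Optional
from pydantic import BaseModel

class ChoiceDeltaToolCallFunction(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None

class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: str
    type: str
    function: ChoiceDeltaToolCallFunction

class ChoiceDelta(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None

# A list of tool-call deltas now validates; the old single-object shape would not.
delta = ChoiceDelta(
    role="assistant",
    tool_calls=[
        ChoiceDeltaToolCall(
            index=0,
            id="",
            type="function",
            function=ChoiceDeltaToolCallFunction(name="get_current_weather", arguments='"}'),
        )
    ],
)
assert delta.tool_calls[0].function.name == "get_current_weather"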

View File

@@ -0,0 +1 @@
"ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='', function=ChoiceDeltaToolCallFunction(arguments='\"}', name='get_current_weather'), type='function')]), finish_reason=None, index=0, logprobs=None)], created=1739798781, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion.chunk', service_tier=None, system_fingerprint='3.1.1-dev0-native', usage=None)"

View File

@@ -0,0 +1,29 @@
{
  "choices": [
    {
      "delta": {
        "role": "assistant",
        "tool_calls": [
          {
            "function": {
              "arguments": "\"}",
              "name": "get_current_weather"
            },
            "id": "",
            "index": 0,
            "type": "function"
          }
        ]
      },
      "finish_reason": null,
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1739799458,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "3.1.1-dev0-native",
  "usage": null
}
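
Note: the `arguments` value in this snapshot (`"}`) is only the final fragment of the streamed tool call; clients concatenate the fragments across chunks to recover the full JSON. A small sketch with hypothetical fragments that join to the string asserted in the new test:

import json

# Hypothetical fragment order; the real chunking depends on the model's tokens.
fragments = ['{ "location": "San Francisco', ', CA", "format": "fahren', 'heit', '"}']
joined = "".join(fragments)
assert joined == '{ "location": "San Francisco, CA", "format": "fahrenheit"}'
assert json.loads(joined) == {"location": "San Francisco, CA", "format": "fahrenheit"}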

View File

@@ -5,7 +5,7 @@
"index": 0,
"logprobs": null,
"message": {
"content": "I am an AI assistant",
"content": "I am a helpful assistant!",
"name": null,
"role": "assistant",
"tool_calls": null
@@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1728497062,
"created": 1739357385,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": {
"completion_tokens": 23,
"prompt_tokens": 604,
"total_tokens": 627
"prompt_tokens": 494,
"total_tokens": 517
}
}

View File

@@ -11,10 +11,10 @@
"logprobs": null
}
],
"created": 1728497531,
"created": 1739441937,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -2,7 +2,7 @@
"choices": [
{
"delta": {
"content": " fans",
"content": " Oracle",
"role": "assistant",
"tool_calls": null
},
@@ -11,10 +11,10 @@
"logprobs": null
}
],
"created": 1728497461,
"created": 1739444803,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -3,25 +3,27 @@
{
"delta": {
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"arguments": "}",
"name": "get_n_day_weather_forecast"
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": "stop",
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1732293254,
"created": 1739797595,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.1-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -11,10 +11,10 @@
"logprobs": null
}
],
"created": 1729262528,
"created": 1739454835,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -4,25 +4,27 @@
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"arguments": "}",
"name": "get_n_day_weather_forecast"
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": "stop",
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1732293246,
"created": 1739456930,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.1-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -4,25 +4,27 @@
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"arguments": "\"}",
"name": "get_current_weather"
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": "stop",
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1732293235,
"created": 1739367874,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.1-dev0-native",
"system_fingerprint": "3.1.1-dev0-native",
"usage": null
}

View File

@@ -0,0 +1,110 @@
from openai import OpenAI
import pytest


@pytest.fixture(scope="module")
def openai_llama_tools_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def openai_llama_tools(openai_llama_tools_handle):
    await openai_llama_tools_handle.health(300)
    return openai_llama_tools_handle.client


tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
                "additionalProperties": False,
            },
        },
    },
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
    client = OpenAI(
        base_url=f"{openai_llama_tools.base_url}/v1",
        api_key="_",
    )
    chat_completion = client.chat.completions.create(
        model="tgi",
        messages=[
            {
                "role": "system",
                "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
            },
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            },
        ],
        tools=tools,
        tool_choice="get_current_weather",
        max_tokens=500,
        stream=True,
        seed=42,
    )
    tool_call_string = ""
    for chunk in chat_completion:
        tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
        last_chunk = chunk.to_dict()
    assert (
        tool_call_string == '{ "location": "San Francisco, CA", "format": "fahrenheit"}'
    )
    assert last_chunk == response_snapshot
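
Note: the test above asserts on a single tool call at index 0. A more general pattern — a sketch, not part of the commit — accumulates streamed deltas per `index`, which also covers responses that emit several tool calls:

# Hedged sketch: generic accumulator for OpenAI-style streamed tool calls.
def accumulate_tool_calls(chunks):
    calls = {}
    for chunk in chunks:
        for tc in chunk.choices[0].delta.tool_calls or []:
            entry = calls.setdefault(tc.index, {"name": "", "arguments": ""})
            if tc.function.name:
                entry["name"] = tc.function.name
            if tc.function.arguments:
                entry["arguments"] += tc.function.arguments
    return calls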

View File

@@ -5,11 +5,7 @@ import json
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle:
with launcher("meta-llama/Meta-Llama-3.1-8B-Instruct") as handle:
yield handle
@@ -101,7 +97,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
},
}
]
@@ -138,7 +134,7 @@ async def test_flash_llama_grammar_tools_auto(
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
},
}
]
@@ -176,7 +172,7 @@ async def test_flash_llama_grammar_tools_choice(
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
},
}
]
@@ -213,15 +209,14 @@ async def test_flash_llama_grammar_tools_stream(
last_response = None
async for response in responses:
count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
tool_calls_generated += (
response.choices[0].delta.tool_calls[0].function.arguments
)
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
)
assert count == 28
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
assert count == 16
assert last_response == response_snapshot
@@ -249,7 +244,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
)
assert responses.choices[0].message.tool_calls is None
assert responses.choices[0].message.content == "I am an AI assistant"
assert responses.choices[0].message.content == "I am a helpful assistant!"
assert responses == response_snapshot
@@ -287,7 +282,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
assert response.choices[0].delta.tool_calls is None
assert count == 5
assert content_generated == "I am an AI assistant"
assert content_generated == "I am a helpful assistant"
assert last_response == response_snapshot
@@ -323,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 62
assert count == 77
assert (
content_generated
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
)
assert last_response == response_snapshot
@@ -360,13 +355,15 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
async for response in responses:
count += 1
assert response.choices[0].delta.content is None
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
tool_calls_generated += (
response.choices[0].delta.tool_calls[0].function.arguments
)
last_response = response
assert count == 29
assert count == 23
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
)
assert last_response == response_snapshot
@@ -457,15 +454,14 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
if line == "[DONE]":
break
response = json.loads(line)
tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
"function"
]["arguments"]
tool_call = response["choices"][0]["delta"]["tool_calls"][0]
tool_calls_generated += tool_call["function"]["arguments"]
last_response = response
assert count == 39
assert count == 25
assert (
tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
)
assert last_response == response_snapshot
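
Note: this last test consumes the raw SSE stream rather than the client object. A minimal sketch of that style of loop, assuming a hypothetical `lines` iterable that yields the payload after each "data:" prefix, as in the test above:

import json

# Hedged sketch: collect tool-call argument fragments from decoded SSE payloads.
def collect_arguments(lines):
    tool_calls_generated = ""
    for line in lines:
        if line == "[DONE]":
            break
        response = json.loads(line)
        tool_call = response["choices"][0]["delta"]["tool_calls"][0]
        tool_calls_generated += tool_call["function"]["arguments"]
    return tool_calls_generated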

View File

@@ -1142,9 +1142,7 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
// escape the token text so it's a JSON string
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
(None, Some(vec![escaped_text]))
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
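
Note: the removed escaping matters for reassembly. The raw token text is already a fragment of a JSON string, and escaping its quotes would make the concatenated arguments unparseable. A quick Python illustration with hypothetical fragments:

import json

fragments = ['{ "location": "Paris, France', '"}']  # hypothetical raw tokens
json.loads("".join(fragments))  # parses: {"location": "Paris, France"}

escaped = "".join(f.replace('"', '\\"') for f in fragments)
try:
    json.loads(escaped)
except json.JSONDecodeError:
    pass  # escaping the quotes breaks the reassembled JSON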
@@ -1307,7 +1305,8 @@ pub(crate) async fn chat_completions(
state = StreamState::Content {
skip_close_quote: false,
};
buffer = buffer.drain(0..1).collect();
buffer.drain(1..); // only keep the first token (opening '{')
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
}
}
}
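
Note: the two new lines replace the old single-element drain. The buffer is cut down to its first token, and that token's text is truncated so that only the opening '{' remains. A rough Python analogue, with tokens simplified to dicts:

# Hedged analogue of buffer.drain(1..) plus the one-character truncation.
buffer = [{"text": '{"function'}, {"text": ': {"_name"'}]  # hypothetical tokens
del buffer[1:]                               # keep only the first token
buffer[0]["text"] = buffer[0]["text"][:1]    # keep only the opening '{'
assert buffer == [{"text": "{"}]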
@@ -1361,40 +1360,59 @@
}
buffer.push(stream_token);
// FIFO send the buffer but leave the last two elements (closing '}' and EOS token)
for stream_token in &buffer[..buffer.len() - 2] {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
if buffer.len() > 1 {
// FIFO send the buffer but leave the last two elements (closing '}' and EOS token)
for stream_token in &buffer[..buffer.len() - 2] {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
yield Ok::<Event, Infallible>(event);
yield Ok::<Event, Infallible>(event);
}
buffer = buffer.drain(buffer.len() - 2..).collect();
}
buffer = buffer.drain(buffer.len() - 2..).collect();
}
}
}
Err(err) => yield Ok(err.into_openai_event())
}
}
// send the second to last stream token but remove the trailing '}' if it exists
let mut closing_stream_token = buffer.remove(0);
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
let event = create_event_from_stream_token(
&closing_stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
yield Ok::<Event, Infallible>(event);
if response_as_tool {
// send the second to last stream token but remove the trailing '}' if it exists
let mut closing_stream_token = buffer.remove(0);
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
let event = create_event_from_stream_token(
&closing_stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
yield Ok::<Event, Infallible>(event);
} else {
// send each buffer element
for stream_token in buffer {
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
yield Ok::<Event, Infallible>(event);
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
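
Note: taken together, the final flush now branches on `response_as_tool`. In tool mode only the second-to-last token is sent, with any trailing '}' stripped (the grammar's closing brace and the EOS token are withheld); in plain-content mode every remaining buffered token is emitted before `[DONE]`. A compact Python sketch of that flow, with tokens simplified to dicts:

# Hedged sketch of the end-of-stream flush; not the Rust implementation itself.
def flush(buffer, response_as_tool):
    events = []
    if response_as_tool:
        closing = buffer.pop(0)        # second-to-last token of the stream
        text = closing["text"]
        if text.endswith("}"):
            text = text[:-1]           # strip the grammar's closing '}'
        events.append({"text": text})
        # the remaining EOS token is intentionally not sent
    else:
        events.extend(buffer)          # plain content: send everything buffered
    events.append("[DONE]")
    return events

assert flush([{"text": 'heit"}'}, {"text": "<|eot_id|>"}], True)[0] == {"text": 'heit"'}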