mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
fix: adjust stream, improve tests and add openai client test
This commit is contained in:
parent
07c20903e5
commit
40f905d00b
@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
|
||||
class ChoiceDelta(BaseModel):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[ChoiceDeltaToolCall] = None
|
||||
tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
@ -0,0 +1 @@
|
||||
"ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='', function=ChoiceDeltaToolCallFunction(arguments='\"}', name='get_current_weather'), type='function')]), finish_reason=None, index=0, logprobs=None)], created=1739798781, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion.chunk', service_tier=None, system_fingerprint='3.1.1-dev0-native', usage=None)"
|
@ -0,0 +1,29 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\"}",
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739799458,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
@ -5,7 +5,7 @@
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I am an AI assistant",
|
||||
"content": "I am a helpful assistant!",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
@ -13,14 +13,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1728497062,
|
||||
"created": 1739357385,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.4.2-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 23,
|
||||
"prompt_tokens": 604,
|
||||
"total_tokens": 627
|
||||
"prompt_tokens": 494,
|
||||
"total_tokens": 517
|
||||
}
|
||||
}
|
||||
|
@ -11,10 +11,10 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1728497531,
|
||||
"created": 1739441937,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.4.2-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " fans",
|
||||
"content": " Oracle",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -11,10 +11,10 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1728497461,
|
||||
"created": 1739444803,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.4.2-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
@ -3,25 +3,27 @@
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "}",
|
||||
"name": "get_n_day_weather_forecast"
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1732293254,
|
||||
"created": 1739797595,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.4.1-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
@ -11,10 +11,10 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1729262528,
|
||||
"created": 1739454835,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.3.2-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
@ -4,25 +4,27 @@
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "}",
|
||||
"name": "get_n_day_weather_forecast"
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1732293246,
|
||||
"created": 1739456930,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.4.1-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
@ -4,25 +4,27 @@
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\"}",
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1732293235,
|
||||
"created": 1739367874,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.4.1-dev0-native",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": null
|
||||
}
|
||||
|
110
integration-tests/models/test_openai_llama_tools.py
Normal file
110
integration-tests/models/test_openai_llama_tools.py
Normal file
@ -0,0 +1,110 @@
|
||||
from openai import OpenAI
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def openai_llama_tools_handle(launcher):
|
||||
with launcher(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
num_shard=2,
|
||||
disable_grammar_support=False,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def openai_llama_tools(openai_llama_tools_handle):
|
||||
await openai_llama_tools_handle.health(300)
|
||||
return openai_llama_tools_handle.client
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_n_day_weather_forecast",
|
||||
"description": "Get an N-day weather forecast",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
"num_days": {
|
||||
"type": "integer",
|
||||
"description": "The number of days to forecast",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
|
||||
client = OpenAI(
|
||||
base_url=f"{openai_llama_tools.base_url}/v1",
|
||||
api_key="_",
|
||||
)
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="get_current_weather",
|
||||
max_tokens=500,
|
||||
stream=True,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
tool_call_string = ""
|
||||
for chunk in chat_completion:
|
||||
tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
last_chunk = chunk.to_dict()
|
||||
|
||||
assert (
|
||||
tool_call_string == '{ "location": "San Francisco, CA", "format": "fahrenheit"}'
|
||||
)
|
||||
assert last_chunk == response_snapshot
|
@ -5,11 +5,7 @@ import json
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_grammar_tools_handle(launcher):
|
||||
with launcher(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
num_shard=2,
|
||||
disable_grammar_support=False,
|
||||
) as handle:
|
||||
with launcher("meta-llama/Meta-Llama-3.1-8B-Instruct") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@ -101,7 +97,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -138,7 +134,7 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -176,7 +172,7 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -213,15 +209,14 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
last_response = None
|
||||
async for response in responses:
|
||||
count += 1
|
||||
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
|
||||
tool_calls_generated += (
|
||||
response.choices[0].delta.tool_calls[0].function.arguments
|
||||
)
|
||||
last_response = response
|
||||
assert response.choices[0].delta.content is None
|
||||
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
|
||||
)
|
||||
assert count == 28
|
||||
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
|
||||
assert count == 16
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@ -249,7 +244,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
)
|
||||
|
||||
assert responses.choices[0].message.tool_calls is None
|
||||
assert responses.choices[0].message.content == "I am an AI assistant"
|
||||
assert responses.choices[0].message.content == "I am a helpful assistant!"
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
||||
@ -287,7 +282,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 5
|
||||
assert content_generated == "I am an AI assistant"
|
||||
assert content_generated == "I am a helpful assistant"
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@ -323,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
|
||||
last_response = response
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 62
|
||||
assert count == 77
|
||||
assert (
|
||||
content_generated
|
||||
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
|
||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
@ -360,13 +355,15 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
||||
async for response in responses:
|
||||
count += 1
|
||||
assert response.choices[0].delta.content is None
|
||||
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
|
||||
tool_calls_generated += (
|
||||
response.choices[0].delta.tool_calls[0].function.arguments
|
||||
)
|
||||
last_response = response
|
||||
|
||||
assert count == 29
|
||||
assert count == 23
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
|
||||
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
@ -457,15 +454,14 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||
if line == "[DONE]":
|
||||
break
|
||||
response = json.loads(line)
|
||||
tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
|
||||
"function"
|
||||
]["arguments"]
|
||||
tool_call = response["choices"][0]["delta"]["tool_calls"][0]
|
||||
tool_calls_generated += tool_call["function"]["arguments"]
|
||||
last_response = response
|
||||
|
||||
assert count == 39
|
||||
assert count == 25
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
||||
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
@ -1142,9 +1142,7 @@ fn create_event_from_stream_token(
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
// escape the token text so its a json string
|
||||
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
|
||||
(None, Some(vec![escaped_text]))
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
@ -1307,7 +1305,8 @@ pub(crate) async fn chat_completions(
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
buffer = buffer.drain(0..1).collect();
|
||||
buffer.drain(1..); // only keep the first token (opening '{')
|
||||
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1361,40 +1360,59 @@ pub(crate) async fn chat_completions(
|
||||
}
|
||||
|
||||
buffer.push(stream_token);
|
||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
||||
for stream_token in &buffer[..buffer.len() - 2] {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
if buffer.len() > 1 {
|
||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
||||
for stream_token in &buffer[..buffer.len() - 2] {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||
}
|
||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => yield Ok(err.into_openai_event())
|
||||
}
|
||||
}
|
||||
// send the second to last stream token but remove the trailing '}' if it exists
|
||||
let mut closing_stream_token = buffer.remove(0);
|
||||
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
||||
let event = create_event_from_stream_token(
|
||||
&closing_stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
if response_as_tool {
|
||||
// send the second to last stream token but remove the trailing '}' if it exists
|
||||
let mut closing_stream_token = buffer.remove(0);
|
||||
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
||||
let event = create_event_from_stream_token(
|
||||
&closing_stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
} else {
|
||||
// send each buffer element
|
||||
for stream_token in buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
|
||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user