mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
fix: adjust stream, improve tests and add openai client test
This commit is contained in:
parent
07c20903e5
commit
40f905d00b
@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
|
|||||||
class ChoiceDelta(BaseModel):
|
class ChoiceDelta(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
tool_calls: Optional[ChoiceDeltaToolCall] = None
|
tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
|
||||||
|
|
||||||
|
|
||||||
class Choice(BaseModel):
|
class Choice(BaseModel):
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
"ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='', function=ChoiceDeltaToolCallFunction(arguments='\"}', name='get_current_weather'), type='function')]), finish_reason=None, index=0, logprobs=None)], created=1739798781, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion.chunk', service_tier=None, system_fingerprint='3.1.1-dev0-native', usage=None)"
|
@ -0,0 +1,29 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"arguments": "\"}",
|
||||||
|
"name": "get_current_weather"
|
||||||
|
},
|
||||||
|
"id": "",
|
||||||
|
"index": 0,
|
||||||
|
"type": "function"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1739799458,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
}
|
@ -5,7 +5,7 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "I am an AI assistant",
|
"content": "I am a helpful assistant!",
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
@ -13,14 +13,14 @@
|
|||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1728497062,
|
"created": 1739357385,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"system_fingerprint": "2.4.2-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 23,
|
"completion_tokens": 23,
|
||||||
"prompt_tokens": 604,
|
"prompt_tokens": 494,
|
||||||
"total_tokens": 627
|
"total_tokens": 517
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1728497531,
|
"created": 1739441937,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.4.2-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " fans",
|
"content": " Oracle",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -11,10 +11,10 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1728497461,
|
"created": 1739444803,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.4.2-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
@ -3,25 +3,27 @@
|
|||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": {
|
"tool_calls": [
|
||||||
"function": {
|
{
|
||||||
"arguments": "<|eot_id|>",
|
"function": {
|
||||||
"name": null
|
"arguments": "}",
|
||||||
},
|
"name": "get_n_day_weather_forecast"
|
||||||
"id": "",
|
},
|
||||||
"index": 0,
|
"id": "",
|
||||||
"type": "function"
|
"index": 0,
|
||||||
}
|
"type": "function"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"finish_reason": "stop",
|
"finish_reason": null,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1732293254,
|
"created": 1739797595,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.4.1-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1729262528,
|
"created": 1739454835,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.3.2-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
@ -4,25 +4,27 @@
|
|||||||
"delta": {
|
"delta": {
|
||||||
"content": null,
|
"content": null,
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": {
|
"tool_calls": [
|
||||||
"function": {
|
{
|
||||||
"arguments": "<|eot_id|>",
|
"function": {
|
||||||
"name": null
|
"arguments": "}",
|
||||||
},
|
"name": "get_n_day_weather_forecast"
|
||||||
"id": "",
|
},
|
||||||
"index": 0,
|
"id": "",
|
||||||
"type": "function"
|
"index": 0,
|
||||||
}
|
"type": "function"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"finish_reason": "stop",
|
"finish_reason": null,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1732293246,
|
"created": 1739456930,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.4.1-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
@ -4,25 +4,27 @@
|
|||||||
"delta": {
|
"delta": {
|
||||||
"content": null,
|
"content": null,
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": {
|
"tool_calls": [
|
||||||
"function": {
|
{
|
||||||
"arguments": "<|eot_id|>",
|
"function": {
|
||||||
"name": null
|
"arguments": "\"}",
|
||||||
},
|
"name": "get_current_weather"
|
||||||
"id": "",
|
},
|
||||||
"index": 0,
|
"id": "",
|
||||||
"type": "function"
|
"index": 0,
|
||||||
}
|
"type": "function"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"finish_reason": "stop",
|
"finish_reason": null,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1732293235,
|
"created": 1739367874,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.4.1-dev0-native",
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
110
integration-tests/models/test_openai_llama_tools.py
Normal file
110
integration-tests/models/test_openai_llama_tools.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def openai_llama_tools_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
num_shard=2,
|
||||||
|
disable_grammar_support=False,
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def openai_llama_tools(openai_llama_tools_handle):
|
||||||
|
await openai_llama_tools_handle.health(300)
|
||||||
|
return openai_llama_tools_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
"additionalProperties": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_n_day_weather_forecast",
|
||||||
|
"description": "Get an N-day weather forecast",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location.",
|
||||||
|
},
|
||||||
|
"num_days": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The number of days to forecast",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format", "num_days"],
|
||||||
|
"additionalProperties": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
|
||||||
|
client = OpenAI(
|
||||||
|
base_url=f"{openai_llama_tools.base_url}/v1",
|
||||||
|
api_key="_",
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_completion = client.chat.completions.create(
|
||||||
|
model="tgi",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="get_current_weather",
|
||||||
|
max_tokens=500,
|
||||||
|
stream=True,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call_string = ""
|
||||||
|
for chunk in chat_completion:
|
||||||
|
tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||||
|
last_chunk = chunk.to_dict()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
tool_call_string == '{ "location": "San Francisco, CA", "format": "fahrenheit"}'
|
||||||
|
)
|
||||||
|
assert last_chunk == response_snapshot
|
@ -5,11 +5,7 @@ import json
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_grammar_tools_handle(launcher):
|
def flash_llama_grammar_tools_handle(launcher):
|
||||||
with launcher(
|
with launcher("meta-llama/Meta-Llama-3.1-8B-Instruct") as handle:
|
||||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
||||||
num_shard=2,
|
|
||||||
disable_grammar_support=False,
|
|
||||||
) as handle:
|
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@ -101,7 +97,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
|||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -138,7 +134,7 @@ async def test_flash_llama_grammar_tools_auto(
|
|||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -176,7 +172,7 @@ async def test_flash_llama_grammar_tools_choice(
|
|||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
"arguments": '{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -213,15 +209,14 @@ async def test_flash_llama_grammar_tools_stream(
|
|||||||
last_response = None
|
last_response = None
|
||||||
async for response in responses:
|
async for response in responses:
|
||||||
count += 1
|
count += 1
|
||||||
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
|
tool_calls_generated += (
|
||||||
|
response.choices[0].delta.tool_calls[0].function.arguments
|
||||||
|
)
|
||||||
last_response = response
|
last_response = response
|
||||||
assert response.choices[0].delta.content is None
|
assert response.choices[0].delta.content is None
|
||||||
|
|
||||||
assert (
|
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
|
||||||
tool_calls_generated
|
assert count == 16
|
||||||
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
|
|
||||||
)
|
|
||||||
assert count == 28
|
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -249,7 +244,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert responses.choices[0].message.tool_calls is None
|
assert responses.choices[0].message.tool_calls is None
|
||||||
assert responses.choices[0].message.content == "I am an AI assistant"
|
assert responses.choices[0].message.content == "I am a helpful assistant!"
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
||||||
@ -287,7 +282,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
|||||||
assert response.choices[0].delta.tool_calls is None
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
assert count == 5
|
assert count == 5
|
||||||
assert content_generated == "I am an AI assistant"
|
assert content_generated == "I am a helpful assistant"
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -323,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
|
|||||||
last_response = response
|
last_response = response
|
||||||
assert response.choices[0].delta.tool_calls is None
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
assert count == 62
|
assert count == 77
|
||||||
assert (
|
assert (
|
||||||
content_generated
|
content_generated
|
||||||
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
|
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
@ -360,13 +355,15 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
|||||||
async for response in responses:
|
async for response in responses:
|
||||||
count += 1
|
count += 1
|
||||||
assert response.choices[0].delta.content is None
|
assert response.choices[0].delta.content is None
|
||||||
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
|
tool_calls_generated += (
|
||||||
|
response.choices[0].delta.tool_calls[0].function.arguments
|
||||||
|
)
|
||||||
last_response = response
|
last_response = response
|
||||||
|
|
||||||
assert count == 29
|
assert count == 23
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
|
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
@ -457,15 +454,14 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
|||||||
if line == "[DONE]":
|
if line == "[DONE]":
|
||||||
break
|
break
|
||||||
response = json.loads(line)
|
response = json.loads(line)
|
||||||
tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
|
tool_call = response["choices"][0]["delta"]["tool_calls"][0]
|
||||||
"function"
|
tool_calls_generated += tool_call["function"]["arguments"]
|
||||||
]["arguments"]
|
|
||||||
last_response = response
|
last_response = response
|
||||||
|
|
||||||
assert count == 39
|
assert count == 25
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
@ -1142,9 +1142,7 @@ fn create_event_from_stream_token(
|
|||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if inner_using_tools {
|
let (content, tool_calls) = if inner_using_tools {
|
||||||
// escape the token text so its a json string
|
(None, Some(vec![stream_token.token.text.clone()]))
|
||||||
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
|
|
||||||
(None, Some(vec![escaped_text]))
|
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
Some(stream_token.token.text.clone())
|
Some(stream_token.token.text.clone())
|
||||||
@ -1307,7 +1305,8 @@ pub(crate) async fn chat_completions(
|
|||||||
state = StreamState::Content {
|
state = StreamState::Content {
|
||||||
skip_close_quote: false,
|
skip_close_quote: false,
|
||||||
};
|
};
|
||||||
buffer = buffer.drain(0..1).collect();
|
buffer.drain(1..); // only keep the first token (opening '{')
|
||||||
|
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1361,40 +1360,59 @@ pub(crate) async fn chat_completions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer.push(stream_token);
|
buffer.push(stream_token);
|
||||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
if buffer.len() > 1 {
|
||||||
for stream_token in &buffer[..buffer.len() - 2] {
|
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
||||||
let event = create_event_from_stream_token(
|
for stream_token in &buffer[..buffer.len() - 2] {
|
||||||
stream_token,
|
let event = create_event_from_stream_token(
|
||||||
logprobs,
|
stream_token,
|
||||||
stream_options.clone(),
|
logprobs,
|
||||||
response_as_tool,
|
stream_options.clone(),
|
||||||
system_fingerprint.clone(),
|
response_as_tool,
|
||||||
model_id.clone(),
|
system_fingerprint.clone(),
|
||||||
Some(global_function_name.clone()),
|
model_id.clone(),
|
||||||
);
|
Some(global_function_name.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
|
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||||
}
|
}
|
||||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => yield Ok(err.into_openai_event())
|
Err(err) => yield Ok(err.into_openai_event())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// send the second to last stream token but remove the trailing '}' if it exists
|
if response_as_tool {
|
||||||
let mut closing_stream_token = buffer.remove(0);
|
// send the second to last stream token but remove the trailing '}' if it exists
|
||||||
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
let mut closing_stream_token = buffer.remove(0);
|
||||||
let event = create_event_from_stream_token(
|
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
||||||
&closing_stream_token,
|
let event = create_event_from_stream_token(
|
||||||
logprobs,
|
&closing_stream_token,
|
||||||
stream_options.clone(),
|
logprobs,
|
||||||
response_as_tool,
|
stream_options.clone(),
|
||||||
system_fingerprint.clone(),
|
response_as_tool,
|
||||||
model_id.clone(),
|
system_fingerprint.clone(),
|
||||||
Some(global_function_name.clone()),
|
model_id.clone(),
|
||||||
);
|
Some(global_function_name.clone()),
|
||||||
yield Ok::<Event, Infallible>(event);
|
);
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
} else {
|
||||||
|
// send each buffer element
|
||||||
|
for stream_token in buffer {
|
||||||
|
let event = create_event_from_stream_token(
|
||||||
|
&stream_token,
|
||||||
|
logprobs,
|
||||||
|
stream_options.clone(),
|
||||||
|
response_as_tool,
|
||||||
|
system_fingerprint.clone(),
|
||||||
|
model_id.clone(),
|
||||||
|
Some(global_function_name.clone()),
|
||||||
|
);
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user