mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Fix tool call4 (#3094)
* Removing the no_tool content information. * Removing a lot of NO_TOOL shenanigans. * Update the tests.
This commit is contained in:
parent
ed46c2c414
commit
5c5528e362
@ -5,20 +5,20 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "I am a helpful assistant!",
|
"content": "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741263686,
|
"created": 1741693957,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"system_fingerprint": "3.1.2-dev0-native",
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 23,
|
"completion_tokens": 12,
|
||||||
"prompt_tokens": 494,
|
"prompt_tokens": 53,
|
||||||
"total_tokens": 517
|
"total_tokens": 65
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,24 +1,4 @@
|
|||||||
[
|
[
|
||||||
{
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"delta": {
|
|
||||||
"content": "",
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": null
|
|
||||||
},
|
|
||||||
"finish_reason": null,
|
|
||||||
"index": 0,
|
|
||||||
"logprobs": null
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"created": 1741364571,
|
|
||||||
"id": "",
|
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"system_fingerprint": "3.1.2-dev0-native",
|
|
||||||
"usage": null
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
@ -32,7 +12,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@ -43,7 +23,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " am",
|
"content": "'m",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -52,7 +32,127 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " an",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " artificial",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " intelligence",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " model",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " known",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " as",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@ -72,7 +172,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@ -83,7 +183,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " helpful",
|
"content": " large",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -92,7 +192,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@ -103,7 +203,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " assistant",
|
"content": " language",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -112,7 +212,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@ -123,7 +223,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": "!",
|
"content": " model",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -132,7 +232,167 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1741364571,
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " (",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "LL",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "M",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": ")",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " or",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " convers",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "ational",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " AI",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741694017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -279,7 +279,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
|
|||||||
):
|
):
|
||||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||||
response = client.chat_completion(
|
response = client.chat_completion(
|
||||||
max_tokens=100,
|
max_tokens=20,
|
||||||
seed=24,
|
seed=24,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
@ -299,7 +299,10 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
|
|||||||
content_generated = response.choices[0].message.content
|
content_generated = response.choices[0].message.content
|
||||||
assert response.choices[0].message.tool_calls is None
|
assert response.choices[0].message.tool_calls is None
|
||||||
|
|
||||||
assert content_generated == "I am a helpful assistant!"
|
assert (
|
||||||
|
content_generated
|
||||||
|
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
|
||||||
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -310,7 +313,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
|||||||
):
|
):
|
||||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||||
stream = client.chat_completion(
|
stream = client.chat_completion(
|
||||||
max_tokens=100,
|
max_tokens=20,
|
||||||
seed=24,
|
seed=24,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
@ -335,7 +338,10 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
|||||||
assert chunk.choices[0].delta.tool_calls is None
|
assert chunk.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
######## This is exactly the same as the non streaming case
|
######## This is exactly the same as the non streaming case
|
||||||
assert content_generated == "I am a helpful assistant!"
|
assert (
|
||||||
|
content_generated
|
||||||
|
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
|
||||||
|
)
|
||||||
assert chunks == response_snapshot
|
assert chunks == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -346,7 +352,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
|
|||||||
):
|
):
|
||||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||||
stream = client.chat_completion(
|
stream = client.chat_completion(
|
||||||
max_tokens=100,
|
max_tokens=20,
|
||||||
seed=24,
|
seed=24,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
@ -372,7 +378,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
content_generated
|
content_generated
|
||||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle."
|
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish,"
|
||||||
)
|
)
|
||||||
assert chunks == response_snapshot
|
assert chunks == response_snapshot
|
||||||
|
|
||||||
|
@ -6,22 +6,6 @@ use crate::{
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
enum _NoTool {
|
|
||||||
NoTool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct NoToolCall {
|
|
||||||
_name: _NoTool,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct NoTool {
|
|
||||||
function: NoToolCall,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ToolCall {
|
struct ToolCall {
|
||||||
_name: String,
|
_name: String,
|
||||||
@ -34,9 +18,19 @@ struct Call {
|
|||||||
function: ToolCall,
|
function: ToolCall,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn parse_output(
|
#[cfg_attr(test, derive(Debug))]
|
||||||
generated_text: &str,
|
pub(crate) enum ChatEvent {
|
||||||
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> {
|
NoTool,
|
||||||
|
Events(Vec<CompletionType>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(test, derive(Debug))]
|
||||||
|
pub(crate) enum ChatChoice {
|
||||||
|
NoTool,
|
||||||
|
ToolCalls(Vec<crate::ToolCall>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
|
||||||
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
|
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
|
||||||
InferError::ToolError(format!(
|
InferError::ToolError(format!(
|
||||||
"Failed to parse generated text: {} {:?}",
|
"Failed to parse generated text: {} {:?}",
|
||||||
@ -48,16 +42,7 @@ pub(crate) fn parse_output(
|
|||||||
match &name[..] {
|
match &name[..] {
|
||||||
"no_tool" => {
|
"no_tool" => {
|
||||||
// parse the content message
|
// parse the content message
|
||||||
let content_message = call
|
Ok(ChatChoice::NoTool)
|
||||||
.function
|
|
||||||
.arguments
|
|
||||||
.get("content")
|
|
||||||
.and_then(Value::as_str)
|
|
||||||
.ok_or_else(|| {
|
|
||||||
InferError::ToolError("No `content` found in generated text".to_string())
|
|
||||||
})?
|
|
||||||
.to_string();
|
|
||||||
Ok((None, Some(content_message)))
|
|
||||||
}
|
}
|
||||||
name => {
|
name => {
|
||||||
let tool_calls = vec![crate::ToolCall {
|
let tool_calls = vec![crate::ToolCall {
|
||||||
@ -73,7 +58,7 @@ pub(crate) fn parse_output(
|
|||||||
})?,
|
})?,
|
||||||
},
|
},
|
||||||
}];
|
}];
|
||||||
Ok((Some(tool_calls), None))
|
Ok(ChatChoice::ToolCalls(tool_calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,10 +143,6 @@ enum StreamState {
|
|||||||
Buffering,
|
Buffering,
|
||||||
/// We detected a tool call here
|
/// We detected a tool call here
|
||||||
Tool,
|
Tool,
|
||||||
/// During the `content` part of the tool call
|
|
||||||
NoTool,
|
|
||||||
/// Finishing frames of the ToolCall
|
|
||||||
NoToolFinish,
|
|
||||||
/// This is without tool calling
|
/// This is without tool calling
|
||||||
Content,
|
Content,
|
||||||
}
|
}
|
||||||
@ -202,34 +183,16 @@ impl ChatState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
|
pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent {
|
||||||
let mut events = vec![];
|
let mut events = vec![];
|
||||||
let token_text = &stream_token.token.text;
|
let token_text = &stream_token.token.text;
|
||||||
match self.state {
|
match self.state {
|
||||||
StreamState::Buffering => {
|
StreamState::Buffering => {
|
||||||
self.text.push_str(token_text);
|
self.text.push_str(token_text);
|
||||||
// We have a special match for `no_tool` in order to capture directly the `content`
|
tracing::info!("Current text {:?}", self.text);
|
||||||
// key which should be re-emitted as raw text.
|
|
||||||
if let Ok(value) = serde_json::from_str::<NoTool>(&format!("{}\"}}}}", self.text)) {
|
|
||||||
self.state = StreamState::NoTool;
|
|
||||||
// Modifiy the content of the token to be whatever was captured by the JSON
|
|
||||||
stream_token.token.text = value.function.content;
|
|
||||||
let chat_complete = create_event_from_stream_token(
|
|
||||||
&stream_token,
|
|
||||||
self.logprobs,
|
|
||||||
false,
|
|
||||||
self.fingerprint.clone(),
|
|
||||||
self.model_id.clone(),
|
|
||||||
None,
|
|
||||||
self.id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
events.push(chat_complete);
|
|
||||||
}
|
|
||||||
// XXX Caution, here we do not postfix the quote, so that the current output
|
|
||||||
// Is necessarily finished with quotes for us to be able to parse.
|
|
||||||
let partial = &self.text;
|
let partial = &self.text;
|
||||||
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
|
let partial =
|
||||||
|
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
|
||||||
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
||||||
// This can be no_tool before the content has been emitted
|
// This can be no_tool before the content has been emitted
|
||||||
if call.function._name != "no_tool" {
|
if call.function._name != "no_tool" {
|
||||||
@ -246,6 +209,8 @@ impl ChatState {
|
|||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
self.state = StreamState::Tool;
|
self.state = StreamState::Tool;
|
||||||
|
} else {
|
||||||
|
return ChatEvent::NoTool;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -282,50 +247,6 @@ impl ChatState {
|
|||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
|
||||||
// We have remainder tokens, ignore everying,
|
|
||||||
StreamState::NoToolFinish => {}
|
|
||||||
StreamState::NoTool => {
|
|
||||||
self.text.push_str(token_text);
|
|
||||||
if token_text.contains("\"") {
|
|
||||||
let mut text = self
|
|
||||||
.text
|
|
||||||
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
|
|
||||||
// Trim once
|
|
||||||
if text.ends_with("\"") {
|
|
||||||
// Verify we have actually trimmed something
|
|
||||||
// The opposite can happen if the model is outputting inline JSON.
|
|
||||||
text = &text[..text.len() - 1];
|
|
||||||
if let Ok(_value) =
|
|
||||||
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", text))
|
|
||||||
{
|
|
||||||
let mut text = token_text
|
|
||||||
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
|
|
||||||
// Effectively trim_end_match('"', 1)
|
|
||||||
// because we do not want to eventually trim finishing escaped quotes
|
|
||||||
// {{"\"Something\""}}
|
|
||||||
if text.ends_with("\"") {
|
|
||||||
text = &text[..text.len() - 1];
|
|
||||||
}
|
|
||||||
stream_token.token.text = text.to_string();
|
|
||||||
self.state = StreamState::NoToolFinish;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// This escaping is usually inline json escaping and we can therefore remove it.
|
|
||||||
stream_token.token.text = stream_token.token.text.replace("\\", "");
|
|
||||||
let chat_complete = create_event_from_stream_token(
|
|
||||||
&stream_token,
|
|
||||||
self.logprobs,
|
|
||||||
false,
|
|
||||||
self.fingerprint.clone(),
|
|
||||||
self.model_id.clone(),
|
|
||||||
None,
|
|
||||||
self.id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
events.push(chat_complete);
|
|
||||||
}
|
|
||||||
StreamState::Content => {
|
StreamState::Content => {
|
||||||
let chat_complete = create_event_from_stream_token(
|
let chat_complete = create_event_from_stream_token(
|
||||||
&stream_token,
|
&stream_token,
|
||||||
@ -373,7 +294,7 @@ impl ChatState {
|
|||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events
|
ChatEvent::Events(events)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,24 +306,6 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn get_text_content(event: &CompletionType) -> &String {
|
|
||||||
match event {
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
|
||||||
assert_eq!(choices.len(), 1);
|
|
||||||
if let ChatCompletionChoice {
|
|
||||||
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
|
|
||||||
..
|
|
||||||
} = &choices[0]
|
|
||||||
{
|
|
||||||
content
|
|
||||||
} else {
|
|
||||||
panic!("Expected plain message");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => panic!("Unexpected chunk"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
|
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
|
||||||
match event {
|
match event {
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||||
@ -456,6 +359,7 @@ mod tests {
|
|||||||
index: 0,
|
index: 0,
|
||||||
details: None,
|
details: None,
|
||||||
});
|
});
|
||||||
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
match &events[0] {
|
match &events[0] {
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||||
@ -475,6 +379,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
_ => panic!("Unexpected chunk"),
|
_ => panic!("Unexpected chunk"),
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -507,6 +414,7 @@ mod tests {
|
|||||||
finish_reason: FinishReason::Length,
|
finish_reason: FinishReason::Length,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 2);
|
assert_eq!(events.len(), 2);
|
||||||
match &events[0] {
|
match &events[0] {
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||||
@ -540,10 +448,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
_ => panic!("Unexpected chunk"),
|
_ => panic!("Unexpected chunk"),
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_stream_tool_no_tool() {
|
fn test_chat_stream_tool_no_tool_simple() {
|
||||||
let mut chat_state = ChatState::new(
|
let mut chat_state = ChatState::new(
|
||||||
true,
|
true,
|
||||||
StreamOptions {
|
StreamOptions {
|
||||||
@ -597,217 +508,21 @@ mod tests {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..14] {
|
for token in &tokens[..10] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
assert_eq!(events.len(), 0);
|
if let ChatEvent::Events(events) = events {
|
||||||
}
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
|
|
||||||
// No tool output
|
|
||||||
let mut output = String::new();
|
|
||||||
for token in &tokens[14..14 + 7] {
|
|
||||||
let events = chat_state.push(token.clone());
|
|
||||||
assert_eq!(events.len(), 1);
|
|
||||||
let content = get_text_content(&events[0]);
|
|
||||||
output.push_str(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(output, "I am a helpful assistant!");
|
|
||||||
|
|
||||||
// No tool finish
|
|
||||||
for token in &tokens[14 + 7..] {
|
|
||||||
let events = chat_state.push(token.clone());
|
|
||||||
assert_eq!(events.len(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_stream_tool_no_tool_many_quotes() {
|
|
||||||
let mut chat_state = ChatState::new(
|
|
||||||
true,
|
|
||||||
StreamOptions {
|
|
||||||
include_usage: true,
|
|
||||||
},
|
|
||||||
"fingerprint".to_string(),
|
|
||||||
"model_id".to_string(),
|
|
||||||
false,
|
|
||||||
"0".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let tokens = vec![
|
|
||||||
"{\"".to_string(),
|
|
||||||
"function".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" {\"".to_string(),
|
|
||||||
"_".to_string(),
|
|
||||||
"name".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" \"".to_string(),
|
|
||||||
"no".to_string(),
|
|
||||||
"_tool".to_string(),
|
|
||||||
"\",".to_string(),
|
|
||||||
" \"".to_string(),
|
|
||||||
"content".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" \"".to_string(), // Token 14
|
|
||||||
"I".to_string(), // Event 1
|
|
||||||
" am".to_string(), // Event 2
|
|
||||||
" a".to_string(), // Event 3
|
|
||||||
" helpful".to_string(), // Event 4
|
|
||||||
" assistant".to_string(), // Event 5
|
|
||||||
"!\\\"\"".to_string(), // Extra inside the string quote that would get removed
|
|
||||||
"}".to_string(),
|
|
||||||
"}".to_string(),
|
|
||||||
];
|
|
||||||
|
|
||||||
// Initial ignored output
|
|
||||||
for text in &tokens[..14] {
|
|
||||||
let events = chat_state.push(StreamResponse {
|
|
||||||
generated_text: None,
|
|
||||||
token: Token {
|
|
||||||
id: 42,
|
|
||||||
text: text.to_string(),
|
|
||||||
logprob: 0.0,
|
|
||||||
special: false,
|
|
||||||
},
|
|
||||||
top_tokens: vec![],
|
|
||||||
index: 0,
|
|
||||||
details: None,
|
|
||||||
});
|
|
||||||
assert_eq!(events.len(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// No tool output
|
|
||||||
let mut output = String::new();
|
|
||||||
for text in &tokens[14..14 + 7] {
|
|
||||||
let events = chat_state.push(StreamResponse {
|
|
||||||
generated_text: None,
|
|
||||||
token: Token {
|
|
||||||
id: 42,
|
|
||||||
text: text.to_string(),
|
|
||||||
logprob: 0.0,
|
|
||||||
special: false,
|
|
||||||
},
|
|
||||||
top_tokens: vec![],
|
|
||||||
index: 0,
|
|
||||||
details: None,
|
|
||||||
});
|
|
||||||
assert_eq!(events.len(), 1);
|
|
||||||
match &events[0] {
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
|
||||||
assert_eq!(choices.len(), 1);
|
|
||||||
if let ChatCompletionChoice {
|
|
||||||
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
|
|
||||||
..
|
|
||||||
} = &choices[0]
|
|
||||||
{
|
|
||||||
output.push_str(content);
|
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected plain message");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => panic!("Unexpected chunk"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(output, "I am a helpful assistant!\"");
|
|
||||||
|
|
||||||
// No tool finish
|
|
||||||
for text in &tokens[14 + 7..] {
|
|
||||||
let events = chat_state.push(StreamResponse {
|
|
||||||
generated_text: None,
|
|
||||||
token: Token {
|
|
||||||
id: 42,
|
|
||||||
text: text.to_string(),
|
|
||||||
logprob: 0.0,
|
|
||||||
special: false,
|
|
||||||
},
|
|
||||||
top_tokens: vec![],
|
|
||||||
index: 0,
|
|
||||||
details: None,
|
|
||||||
});
|
|
||||||
assert_eq!(events.len(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_stream_tool_no_tool_inline_json() {
|
|
||||||
let mut chat_state = ChatState::new(
|
|
||||||
true,
|
|
||||||
StreamOptions {
|
|
||||||
include_usage: true,
|
|
||||||
},
|
|
||||||
"fingerprint".to_string(),
|
|
||||||
"model_id".to_string(),
|
|
||||||
false,
|
|
||||||
"0".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let tokens = vec![
|
|
||||||
"{\"".to_string(),
|
|
||||||
"function".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" {\"".to_string(),
|
|
||||||
"_".to_string(),
|
|
||||||
"name".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" \"".to_string(),
|
|
||||||
"no".to_string(),
|
|
||||||
"_tool".to_string(),
|
|
||||||
"\",".to_string(),
|
|
||||||
" \"".to_string(),
|
|
||||||
"content".to_string(),
|
|
||||||
"\":".to_string(),
|
|
||||||
" \"".to_string(), // Token 14
|
|
||||||
"{\\\"".to_string(), // Event 1
|
|
||||||
"a".to_string(), // Event 1
|
|
||||||
"\\\":".to_string(), // Event 1
|
|
||||||
"2".to_string(), // Event 2
|
|
||||||
",\\".to_string(), // Event 2
|
|
||||||
"\"".to_string(), // Event 2
|
|
||||||
"b".to_string(), // Event 3
|
|
||||||
"\\\": ".to_string(), // Event 4
|
|
||||||
"1".to_string(), // Event 5
|
|
||||||
"}".to_string(), // Event 5
|
|
||||||
"\"}".to_string(), // Extra inside the string quote that would get removed
|
|
||||||
"}".to_string(),
|
|
||||||
];
|
|
||||||
let tokens: Vec<_> = tokens
|
|
||||||
.into_iter()
|
|
||||||
.map(|text| StreamResponse {
|
|
||||||
generated_text: None,
|
|
||||||
token: Token {
|
|
||||||
id: 42,
|
|
||||||
text: text.to_string(),
|
|
||||||
logprob: 0.0,
|
|
||||||
special: false,
|
|
||||||
},
|
|
||||||
top_tokens: vec![],
|
|
||||||
index: 0,
|
|
||||||
details: None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Initial ignored output
|
|
||||||
for token in &tokens[..14] {
|
|
||||||
let events = chat_state.push(token.clone());
|
|
||||||
assert_eq!(events.len(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// No tool output
|
// No tool output
|
||||||
let mut output = String::new();
|
let events = chat_state.push(tokens[10].clone());
|
||||||
for token in &tokens[14..14 + 12] {
|
if let ChatEvent::NoTool = events {
|
||||||
let events = chat_state.push(token.clone());
|
assert!(true);
|
||||||
assert_eq!(events.len(), 1, "Current text is {output:?}");
|
} else {
|
||||||
let content = get_text_content(&events[0]);
|
panic!("Expected chat events");
|
||||||
output.push_str(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(output, "{\"a\":2,\"b\": 1}");
|
|
||||||
|
|
||||||
// No tool finish
|
|
||||||
for token in &tokens[14 + 12..] {
|
|
||||||
let events = chat_state.push(token.clone());
|
|
||||||
assert_eq!(events.len(), 0, "Extra events {events:?}");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -859,26 +574,21 @@ mod tests {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..13] {
|
for token in &tokens[..10] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
assert_eq!(events.len(), 0);
|
if let ChatEvent::Events(events) = events {
|
||||||
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No tool output
|
// No tool output
|
||||||
let mut output = String::new();
|
let events = chat_state.push(tokens[10].clone());
|
||||||
for token in &tokens[13..13 + 2] {
|
if let ChatEvent::NoTool = events {
|
||||||
let events = chat_state.push(token.clone());
|
assert!(true);
|
||||||
assert_eq!(events.len(), 1, "Current text is {output:?}");
|
} else {
|
||||||
let content = get_text_content(&events[0]);
|
panic!("Expected chat events");
|
||||||
output.push_str(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(output, "");
|
|
||||||
|
|
||||||
// No tool finish
|
|
||||||
for token in &tokens[13 + 2..] {
|
|
||||||
let events = chat_state.push(token.clone());
|
|
||||||
assert_eq!(events.len(), 0, "Extra events {events:?}");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -946,7 +656,11 @@ mod tests {
|
|||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..11] {
|
for token in &tokens[..11] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 0, "{events:?}");
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No tool output
|
// No tool output
|
||||||
@ -954,6 +668,7 @@ mod tests {
|
|||||||
let mut output_name = String::new();
|
let mut output_name = String::new();
|
||||||
for token in &tokens[11..11 + 17] {
|
for token in &tokens[11..11 + 17] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
let (name, arguments) = get_tool_call_content(&events[0]);
|
let (name, arguments) = get_tool_call_content(&events[0]);
|
||||||
if let Some(name) = name {
|
if let Some(name) = name {
|
||||||
@ -961,6 +676,9 @@ mod tests {
|
|||||||
output_name.push_str(&name);
|
output_name.push_str(&name);
|
||||||
}
|
}
|
||||||
output.push_str(arguments);
|
output.push_str(arguments);
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(output_name, "get_current_weather");
|
assert_eq!(output_name, "get_current_weather");
|
||||||
@ -972,7 +690,11 @@ mod tests {
|
|||||||
// No tool finish
|
// No tool finish
|
||||||
for token in &tokens[11 + 17..] {
|
for token in &tokens[11 + 17..] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
assert_eq!(events.len(), 0);
|
if let ChatEvent::Events(events) = events {
|
||||||
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
|
} else {
|
||||||
|
panic!("Expected chat events");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,13 +40,13 @@ impl ToolGrammar {
|
|||||||
),
|
),
|
||||||
arguments: json!({
|
arguments: json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
// "properties": {
|
||||||
"content": {
|
// "content": {
|
||||||
"type": "string",
|
// "type": "string",
|
||||||
"description": "The response content",
|
// "description": "The response content",
|
||||||
}
|
// }
|
||||||
},
|
// },
|
||||||
"required": ["content"]
|
// "required": ["content"]
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::chat::ChatState;
|
use crate::chat::{ChatChoice, ChatEvent, ChatState};
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||||
@ -1151,7 +1151,7 @@ pub(crate) async fn chat_completions(
|
|||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(chat): Json<ChatRequest>,
|
Json(mut chat): Json<ChatRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
@ -1166,7 +1166,7 @@ pub(crate) async fn chat_completions(
|
|||||||
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
||||||
let id = chat.next_tool_call_id();
|
let id = chat.next_tool_call_id();
|
||||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
chat.try_into_generate(&infer)?;
|
chat.clone().try_into_generate(&infer)?;
|
||||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||||
let logprobs = logprobs.unwrap_or_default();
|
let logprobs = logprobs.unwrap_or_default();
|
||||||
|
|
||||||
@ -1178,16 +1178,34 @@ pub(crate) async fn chat_completions(
|
|||||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
// switch on stream
|
// switch on stream
|
||||||
if stream {
|
if stream {
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) = generate_stream_internal(
|
||||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
infer.clone(),
|
||||||
|
compute_type.clone(),
|
||||||
|
Json(generate_request),
|
||||||
|
span.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id);
|
let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result{
|
match result{
|
||||||
Ok(stream_token) => {
|
Ok(stream_token) => {
|
||||||
let events = state.push(stream_token);
|
let events = state.push(stream_token);
|
||||||
|
match events{
|
||||||
|
ChatEvent::NoTool => {
|
||||||
|
chat.tools = None;
|
||||||
|
chat.response_format = None;
|
||||||
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
|
chat.clone().try_into_generate(&infer).unwrap();
|
||||||
|
assert!(!using_tools);
|
||||||
|
let (_headers, response_stream2) =
|
||||||
|
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
|
||||||
|
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
|
||||||
|
response_stream = Box::pin(response_stream2);
|
||||||
|
}
|
||||||
|
ChatEvent::Events(events) => {
|
||||||
for chat_complete in events{
|
for chat_complete in events{
|
||||||
yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {
|
yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||||
@ -1195,6 +1213,8 @@ pub(crate) async fn chat_completions(
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Err(err) => yield Ok(err.into_openai_event())
|
Err(err) => yield Ok(err.into_openai_event())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1204,8 +1224,13 @@ pub(crate) async fn chat_completions(
|
|||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
Ok((headers, sse).into_response())
|
Ok((headers, sse).into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, input_length, Json(generation)) =
|
let (mut headers, mut input_length, Json(generation)) = generate_internal(
|
||||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
Extension(infer.clone()),
|
||||||
|
compute_type.clone(),
|
||||||
|
Json(generate_request),
|
||||||
|
span.clone(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -1213,7 +1238,26 @@ pub(crate) async fn chat_completions(
|
|||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if using_tools {
|
let (tool_calls, output) = if using_tools {
|
||||||
crate::chat::parse_output(&generation.generated_text)?
|
match crate::chat::parse_output(&generation.generated_text)? {
|
||||||
|
ChatChoice::NoTool => {
|
||||||
|
chat.tools = None;
|
||||||
|
chat.response_format = None;
|
||||||
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
|
chat.clone().try_into_generate(&infer)?;
|
||||||
|
assert!(!using_tools);
|
||||||
|
let (headers_final, input_length_final, Json(generation)) = generate_internal(
|
||||||
|
Extension(infer),
|
||||||
|
compute_type,
|
||||||
|
Json(generate_request),
|
||||||
|
span,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
headers = headers_final;
|
||||||
|
input_length = input_length_final;
|
||||||
|
(None, Some(generation.generated_text))
|
||||||
|
}
|
||||||
|
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
(None, Some(generation.generated_text))
|
(None, Some(generation.generated_text))
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user