mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-07 09:52:18 +00:00
Improve tool call message processing (#3036)
* make content field optional in chat request * add tool_calls field to Message struct * feat: add test and serialize tool messages * fix: bump utoipa, openapi doc version and improve test * fix: rerun update docs * fix: support tool call id in template and remove unnecessary changes * fix: ruff lint remove unused import * fix: adjust message types in tests --------- Co-authored-by: sailesh duddupudi <saileshradar@gmail.com>
This commit is contained in:
parent
3498f6085e
commit
1cae3197c4
@ -1865,25 +1865,57 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"Message": {
|
"Message": {
|
||||||
"type": "object",
|
"allOf": [
|
||||||
"required": [
|
{
|
||||||
"role",
|
"$ref": "#/components/schemas/MessageBody"
|
||||||
"content"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"content": {
|
|
||||||
"$ref": "#/components/schemas/MessageContent"
|
|
||||||
},
|
},
|
||||||
"name": {
|
{
|
||||||
"type": "string",
|
"type": "object",
|
||||||
"example": "\"David\"",
|
"required": [
|
||||||
"nullable": true
|
"role"
|
||||||
},
|
],
|
||||||
"role": {
|
"properties": {
|
||||||
"type": "string",
|
"name": {
|
||||||
"example": "user"
|
"type": "string",
|
||||||
|
"example": "\"David\"",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"role": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
]
|
||||||
|
},
|
||||||
|
"MessageBody": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"content"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"content": {
|
||||||
|
"$ref": "#/components/schemas/MessageContent"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"tool_calls"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"tool_calls": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/ToolCall"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"MessageChunk": {
|
"MessageChunk": {
|
||||||
"oneOf": [
|
"oneOf": [
|
||||||
@ -2179,6 +2211,10 @@
|
|||||||
"role": {
|
"role": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "user"
|
"example": "user"
|
||||||
|
},
|
||||||
|
"tool_call_id": {
|
||||||
|
"type": "string",
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1739932427,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.1.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 79,
|
||||||
|
"prompt_tokens": 103,
|
||||||
|
"total_tokens": 182
|
||||||
|
}
|
||||||
|
}
|
@ -468,3 +468,41 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
|||||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_tool_reply_response(
|
||||||
|
flash_llama_grammar_tools, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await flash_llama_grammar_tools.chat(
|
||||||
|
max_tokens=100,
|
||||||
|
seed=42,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris today?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "0",
|
||||||
|
"function": {
|
||||||
|
"arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": None,
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "0", "content": "6.7"},
|
||||||
|
],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert responses.choices[0].message.tool_calls is None
|
||||||
|
assert (
|
||||||
|
responses.choices[0].message.content
|
||||||
|
== "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
use crate::{
|
||||||
|
ChatTemplateInputs, Message, MessageBody, MessageChunk, TextMessage, TokenizerConfigToken, Tool,
|
||||||
|
};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
@ -74,7 +76,9 @@ impl ChatTemplate {
|
|||||||
format!("\n---\n{}", tool_prompt)
|
format!("\n---\n{}", tool_prompt)
|
||||||
};
|
};
|
||||||
if let Some(last_message) = messages.last_mut() {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
last_message.content.push(MessageChunk::Text { text });
|
if let MessageBody::Content { content } = &mut last_message.body {
|
||||||
|
content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(tools)
|
Some(tools)
|
||||||
}
|
}
|
||||||
@ -119,7 +123,8 @@ mod tests {
|
|||||||
use crate::infer::chat_template::{raise_exception, strftime_now};
|
use crate::infer::chat_template::{raise_exception, strftime_now};
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
|
||||||
|
TokenizerConfigToken, Tool,
|
||||||
};
|
};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
@ -158,18 +163,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -186,6 +195,182 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_with_tool_response() {
|
||||||
|
let env = Environment::new();
|
||||||
|
|
||||||
|
// template modified from Llama-3.1-8B-Instruct
|
||||||
|
// https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/0e9e39f249a16976918f6564b8830bc894c89659/tokenizer_config.json#L2053
|
||||||
|
// the main change is accessing `message.tool_call_id` from the messages
|
||||||
|
let source = r#"
|
||||||
|
{{- bos_token }}
|
||||||
|
{%- if custom_tools is defined %}
|
||||||
|
{%- set tools = custom_tools %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not tools_in_user_message is defined %}
|
||||||
|
{%- set tools_in_user_message = true %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not date_string is defined %}
|
||||||
|
{%- set date_string = "26 Jul 2024" %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not tools is defined %}
|
||||||
|
{%- set tools = none %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{#- This block extracts the system message, so we can slot it into the right place. #}
|
||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{%- set system_message = messages[0]['content']|trim %}
|
||||||
|
{%- set messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = "" %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{#- System message + builtin tools #}
|
||||||
|
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
|
||||||
|
{%- if builtin_tools is defined or tools is not none %}
|
||||||
|
{{- "Environment: ipython\n" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if builtin_tools is defined %}
|
||||||
|
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "Cutting Knowledge Date: December 2023\n" }}
|
||||||
|
{{- "Today Date: " + date_string + "\n\n" }}
|
||||||
|
{%- if tools is not none and not tools_in_user_message %}
|
||||||
|
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
|
||||||
|
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
||||||
|
{{- "Do not use variables.\n\n" }}
|
||||||
|
{%- for t in tools %}
|
||||||
|
{{- t | tojson(indent=4) }}
|
||||||
|
{{- "\n\n" }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- system_message }}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
|
||||||
|
{#- Custom tools are passed in a user message with some extra guidance #}
|
||||||
|
{%- if tools_in_user_message and not tools is none %}
|
||||||
|
{#- Extract the first user message so we can plug it in here #}
|
||||||
|
{%- if messages | length != 0 %}
|
||||||
|
{%- set first_user_message = messages[0]['content']|trim %}
|
||||||
|
{%- set messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
|
||||||
|
{{- "Given the following functions, please respond with a JSON for a function call " }}
|
||||||
|
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
|
||||||
|
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
||||||
|
{{- "Do not use variables.\n\n" }}
|
||||||
|
{%- for t in tools %}
|
||||||
|
{{- t | tojson(indent=4) }}
|
||||||
|
{{- "\n\n" }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- first_user_message + "<|eot_id|>"}}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
||||||
|
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
|
||||||
|
{%- elif 'tool_calls' in message %}
|
||||||
|
{%- if not message.tool_calls|length == 1 %}
|
||||||
|
{{- raise_exception("This model only supports single tool-calls at once!") }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set tool_call = message.tool_calls[0].function %}
|
||||||
|
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
||||||
|
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
|
||||||
|
{%- for arg_name, arg_val in tool_call.arguments | items %}
|
||||||
|
{{- arg_name + '="' + arg_val + '"' }}
|
||||||
|
{%- if not loop.last %}
|
||||||
|
{{- ", " }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- ")" }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
||||||
|
{{- '{"name": "' + tool_call.name + '", ' }}
|
||||||
|
{{- '"parameters": ' }}
|
||||||
|
{{- tool_call.arguments | tojson }}
|
||||||
|
{{- "}" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if builtin_tools is defined %}
|
||||||
|
{#- This means we're in ipython mode #}
|
||||||
|
{{- "<|eom_id|>" }}
|
||||||
|
{%- else %}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- elif message.role == "tool" or message.role == "ipython" %}
|
||||||
|
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
|
||||||
|
{{- "TOOL CALL ID: " + message.tool_call_id + "\n\n" }}
|
||||||
|
{%- if message.content is mapping or message.content is iterable %}
|
||||||
|
{{- message.content | tojson }}
|
||||||
|
{%- else %}
|
||||||
|
{{- message.content }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
TextMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
TextMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: r#"[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]"#.to_string(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
TextMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content: "6.7".to_string(),
|
||||||
|
tool_call_id: Some("0".to_string()),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
r#"[BOS]<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 26 Jul 2024
|
||||||
|
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
TOOL CALL ID: 0
|
||||||
|
|
||||||
|
"6.7"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_loop_controls() {
|
fn test_chat_template_loop_controls() {
|
||||||
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
||||||
@ -224,18 +409,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -287,22 +476,27 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi again!".to_string(),
|
content: "Hi again!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -359,18 +553,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -426,18 +624,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -479,18 +681,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -516,14 +722,17 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hello, how are you?".to_string(),
|
content: "Hello, how are you?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "I'm doing great. How can I help you today?".to_string(),
|
content: "I'm doing great. How can I help you today?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "I'd like to show off how chat templating works!".to_string(),
|
content: "I'd like to show off how chat templating works!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -531,6 +740,7 @@ mod tests {
|
|||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: "You are a friendly chatbot who always responds in the style of a pirate"
|
content: "You are a friendly chatbot who always responds in the style of a pirate"
|
||||||
.to_string(),
|
.to_string(),
|
||||||
|
..Default::default()
|
||||||
}]
|
}]
|
||||||
.iter()
|
.iter()
|
||||||
.chain(&example_chat)
|
.chain(&example_chat)
|
||||||
@ -674,10 +884,12 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
@ -949,19 +1161,27 @@ mod tests {
|
|||||||
Message {
|
Message {
|
||||||
name: None,
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText(
|
body: MessageBody::Content {
|
||||||
"I'd like to show off how chat templating works!".to_string(),
|
content: MessageContent::SingleText(
|
||||||
),
|
"I'd like to show off how chat templating works!".to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
name: None,
|
name: None,
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
|
body: MessageBody::Content {
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"Great! How can I help you today?".to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
name: None,
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
body: MessageBody::Content {
|
||||||
|
content: MessageContent::SingleText("Just testing".to_string()),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
@ -985,17 +1205,21 @@ mod tests {
|
|||||||
Message {
|
Message {
|
||||||
name: None,
|
name: None,
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: MessageContent::SingleText(
|
body: MessageBody::Content {
|
||||||
"Youre a helpful assistant! Answer the users question best you can."
|
content: MessageContent::SingleText(
|
||||||
.to_string(),
|
"Youre a helpful assistant! Answer the users question best you can."
|
||||||
),
|
.to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
name: None,
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText(
|
body: MessageBody::Content {
|
||||||
"What is the weather like in Brooklyn, New York?".to_string(),
|
content: MessageContent::SingleText(
|
||||||
),
|
"What is the weather like in Brooklyn, New York?".to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
@ -663,6 +663,7 @@ impl ChatCompletion {
|
|||||||
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
|
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content,
|
content,
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
|
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -673,6 +674,7 @@ impl ChatCompletion {
|
|||||||
OutputMessage::ChatMessage(TextMessage {
|
OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: output,
|
content: output,
|
||||||
|
..Default::default()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
@ -680,6 +682,7 @@ impl ChatCompletion {
|
|||||||
OutputMessage::ChatMessage(TextMessage {
|
OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: "".to_string(),
|
content: "".to_string(),
|
||||||
|
..Default::default()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -767,6 +770,7 @@ impl ChatCompletionChunk {
|
|||||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: delta,
|
content: delta,
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -783,6 +787,7 @@ impl ChatCompletionChunk {
|
|||||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "".to_string(),
|
content: "".to_string(),
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
@ -1025,7 +1030,7 @@ pub fn default_tool_prompt() -> String {
|
|||||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, PartialEq, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum TypedChoice {
|
pub enum TypedChoice {
|
||||||
#[serde(rename = "function")]
|
#[serde(rename = "function")]
|
||||||
@ -1100,19 +1105,19 @@ pub struct JsonSchemaTool {
|
|||||||
properties: Properties,
|
properties: Properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
|
||||||
struct FunctionsMap {
|
struct FunctionsMap {
|
||||||
#[serde(rename = "$functions")]
|
#[serde(rename = "$functions")]
|
||||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
|
||||||
struct FunctionRef {
|
struct FunctionRef {
|
||||||
#[serde(rename = "$ref")]
|
#[serde(rename = "$ref")]
|
||||||
ref_path: String,
|
ref_path: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
|
||||||
struct Properties {
|
struct Properties {
|
||||||
#[serde(serialize_with = "serialize_function")]
|
#[serde(serialize_with = "serialize_function")]
|
||||||
function: Vec<FunctionRef>,
|
function: Vec<FunctionRef>,
|
||||||
@ -1129,7 +1134,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
||||||
pub(crate) struct FunctionDefinition {
|
pub struct FunctionDefinition {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@ -1157,7 +1162,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
|
||||||
pub(crate) struct ToolCall {
|
pub struct ToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
pub function: FunctionDefinition,
|
pub function: FunctionDefinition,
|
||||||
@ -1176,15 +1181,31 @@ pub enum MessageChunk {
|
|||||||
ImageUrl { image_url: Url },
|
ImageUrl { image_url: Url },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
role: String,
|
pub role: String,
|
||||||
|
#[serde(flatten)]
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: MessageContent,
|
pub body: MessageBody,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum MessageBody {
|
||||||
|
// When a regular text message is provided.
|
||||||
|
Content {
|
||||||
|
#[serde(rename = "content")]
|
||||||
|
content: MessageContent,
|
||||||
|
},
|
||||||
|
// When tool calls are provided.
|
||||||
|
Tool {
|
||||||
|
#[serde(rename = "tool_calls")]
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
@ -1211,19 +1232,28 @@ impl MessageContent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)]
|
||||||
pub struct TextMessage {
|
pub struct TextMessage {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
pub role: String,
|
pub role: String,
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Message> for TextMessage {
|
impl From<Message> for TextMessage {
|
||||||
fn from(value: Message) -> Self {
|
fn from(value: Message) -> Self {
|
||||||
|
let content = match value.body {
|
||||||
|
MessageBody::Content { content } => content,
|
||||||
|
MessageBody::Tool { tool_calls } => {
|
||||||
|
let content = serde_json::to_string(&tool_calls).unwrap_or_default();
|
||||||
|
MessageContent::SingleText(content)
|
||||||
|
}
|
||||||
|
};
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: value.role,
|
role: value.role,
|
||||||
content: match value.content {
|
content: match content {
|
||||||
MessageContent::SingleText(text) => text,
|
MessageContent::SingleText(text) => text,
|
||||||
MessageContent::MultipleChunks(chunks) => chunks
|
MessageContent::MultipleChunks(chunks) => chunks
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@ -1234,6 +1264,7 @@ impl From<Message> for TextMessage {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(""),
|
.join(""),
|
||||||
},
|
},
|
||||||
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1565,9 +1596,11 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
request.messages[0],
|
request.messages[0],
|
||||||
Message {
|
Message {
|
||||||
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText("What is Deep Learning?".to_string()),
|
body: MessageBody::Content {
|
||||||
name: None
|
content: MessageContent::SingleText("What is Deep Learning?".to_string())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -1617,13 +1650,16 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
request.messages[0],
|
request.messages[0],
|
||||||
Message{
|
Message {
|
||||||
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::MultipleChunks(vec![
|
|
||||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
body: MessageBody::Content {
|
||||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
|
content: MessageContent::MultipleChunks(vec![
|
||||||
]),
|
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||||
name: None
|
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
|
||||||
|
]),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -1631,12 +1667,14 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn text_message_convert() {
|
fn text_message_convert() {
|
||||||
let message = Message{
|
let message = Message{
|
||||||
|
name: None,
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::MultipleChunks(vec![
|
body: MessageBody::Content {
|
||||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
content: MessageContent::MultipleChunks(vec![
|
||||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
|
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||||
]),
|
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
|
||||||
name: None
|
]),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let textmsg: TextMessage = message.into();
|
let textmsg: TextMessage = message.into();
|
||||||
assert_eq!(textmsg.content, "Whats in this image?");
|
assert_eq!(textmsg.content, "Whats in this image?");
|
||||||
@ -1667,6 +1705,7 @@ mod tests {
|
|||||||
let message = OutputMessage::ChatMessage(TextMessage {
|
let message = OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "This is the answer".to_string(),
|
content: "This is the answer".to_string(),
|
||||||
|
..Default::default()
|
||||||
});
|
});
|
||||||
let serialized = serde_json::to_string(&message).unwrap();
|
let serialized = serde_json::to_string(&message).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -28,7 +28,7 @@ use crate::{
|
|||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::{DefaultBodyLimit, Extension};
|
use axum::extract::{DefaultBodyLimit, Extension};
|
||||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||||
@ -1577,6 +1577,7 @@ FunctionDefinition,
|
|||||||
ToolChoice,
|
ToolChoice,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
ChatTokenizeResponse,
|
ChatTokenizeResponse,
|
||||||
|
MessageBody,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
|
@ -147,7 +147,7 @@ pub(crate) async fn vertex_compatibility(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{Message, MessageContent};
|
use crate::{Message, MessageBody, MessageContent};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vertex_deserialization() {
|
fn vertex_deserialization() {
|
||||||
@ -169,9 +169,13 @@ mod tests {
|
|||||||
VertexRequest {
|
VertexRequest {
|
||||||
instances: vec![VertexInstance::Chat(ChatRequest {
|
instances: vec![VertexInstance::Chat(ChatRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message {
|
||||||
role: "user".to_string(),
|
|
||||||
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
|
||||||
name: None,
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
body: MessageBody::Content {
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"What's Deep Learning?".to_string()
|
||||||
|
)
|
||||||
|
},
|
||||||
},],
|
},],
|
||||||
max_tokens: Some(128),
|
max_tokens: Some(128),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
|
Loading…
Reference in New Issue
Block a user