diff --git a/docs/openapi.json b/docs/openapi.json
index a1df080b..9de76e47 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1865,25 +1865,57 @@
         }
       },
       "Message": {
-        "type": "object",
-        "required": [
-          "role",
-          "content"
-        ],
-        "properties": {
-          "content": {
-            "$ref": "#/components/schemas/MessageContent"
+        "allOf": [
+          {
+            "$ref": "#/components/schemas/MessageBody"
           },
-          "name": {
-            "type": "string",
-            "example": "\"David\"",
-            "nullable": true
-          },
-          "role": {
-            "type": "string",
-            "example": "user"
+          {
+            "type": "object",
+            "required": [
+              "role"
+            ],
+            "properties": {
+              "name": {
+                "type": "string",
+                "example": "\"David\"",
+                "nullable": true
+              },
+              "role": {
+                "type": "string",
+                "example": "user"
+              }
+            }
           }
-        }
+        ]
       },
+      "MessageBody": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "content"
+            ],
+            "properties": {
+              "content": {
+                "$ref": "#/components/schemas/MessageContent"
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "tool_calls"
+            ],
+            "properties": {
+              "tool_calls": {
+                "type": "array",
+                "items": {
+                  "$ref": "#/components/schemas/ToolCall"
+                }
+              }
+            }
+          }
+        ]
+      },
       "MessageChunk": {
         "oneOf": [
@@ -2179,6 +2211,10 @@
         "role": {
           "type": "string",
           "example": "user"
+        },
+        "tool_call_id": {
+          "type": "string",
+          "nullable": true
         }
       }
     },
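The reworked `Message` schema composes the new `MessageBody` `oneOf` with a plain object carrying `role` and the optional `name`, so a valid message supplies either `content` or `tool_calls` alongside `role`. A minimal sketch of the two payload shapes the schema admits (field values are illustrative, not taken from the spec):

```rust
use serde_json::json;

fn main() {
    // MessageBody variant 1: a regular `content` message.
    let content_message = json!({
        "role": "user",
        "content": "What's the weather like in Paris today?"
    });

    // MessageBody variant 2: an assistant turn carrying `tool_calls` instead.
    let tool_call_message = json!({
        "role": "assistant",
        "tool_calls": [{
            "id": "0",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": "{\"longitude\": 2.2945, \"latitude\": 48.8567}"
            }
        }]
    });

    println!("{content_message}\n{tool_call_message}");
}
```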
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json
new file mode 100644
index 00000000..4f10aa3b
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1739932427,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.1.1-dev0-native",
+  "usage": {
+    "completion_tokens": 79,
+    "prompt_tokens": 103,
+    "total_tokens": 182
+  }
+}
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index 70c3aff0..b8a90cff 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -468,3 +468,41 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
         == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
     )
     assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_tool_reply_response(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {"role": "user", "content": "What's the weather like in Paris today?"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": "0",
+                        "function": {
+                            "arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
+                            "name": "get_weather",
+                            "description": None,
+                        },
+                        "type": "function",
+                    }
+                ],
+            },
+            {"role": "tool", "tool_call_id": "0", "content": "6.7"},
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.tool_calls is None
+    assert (
+        responses.choices[0].message.content
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+    )
+
+    assert responses == response_snapshot
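The new test replays a full tool exchange: a user question, an assistant turn that carries only `tool_calls`, and a `tool` reply keyed by `tool_call_id`, then checks that the model folds the `6.7` reading into a plain text answer with `tool_calls` back to `None`. The same exchange as a raw HTTP call against the OpenAI-compatible endpoint (a hedged sketch; the host/port and the locally served model are assumptions about your deployment):

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "max_tokens": 100,
        "seed": 42,
        "messages": [
            { "role": "user", "content": "What's the weather like in Paris today?" },
            { "role": "assistant", "tool_calls": [{
                "id": "0",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"longitude\": 2.2945, \"latitude\": 48.8567}"
                }
            }] },
            { "role": "tool", "tool_call_id": "0", "content": "6.7" }
        ]
    });

    // Assumes a TGI router listening locally; adjust the URL to your setup.
    let response: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()?
        .json()?;

    // As in the pytest: plain text content, no further tool calls.
    println!("{}", response["choices"][0]["message"]["content"]);
    Ok(())
}
```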
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 2fef2dcb..e660cc74 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -1,5 +1,7 @@
 use crate::infer::InferError;
-use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
+use crate::{
+    ChatTemplateInputs, Message, MessageBody, MessageChunk, TextMessage, TokenizerConfigToken, Tool,
+};
 use chrono::Local;
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -74,7 +76,9 @@ impl ChatTemplate {
             format!("\n---\n{}", tool_prompt)
         };
         if let Some(last_message) = messages.last_mut() {
-            last_message.content.push(MessageChunk::Text { text });
+            if let MessageBody::Content { content } = &mut last_message.body {
+                content.push(MessageChunk::Text { text });
+            }
         }
         Some(tools)
     }
@@ -119,7 +123,8 @@ mod tests {
     use crate::infer::chat_template::{raise_exception, strftime_now};
     use crate::infer::ChatTemplate;
     use crate::{
-        ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
+        ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
+        TokenizerConfigToken, Tool,
     };
     use chrono::Local;
     use minijinja::Environment;
@@ -158,18 +163,22 @@ mod tests {
                 TextMessage {
                     role: "user".to_string(),
                     content: "Hi!".to_string(),
+                    ..Default::default()
                 },
                 TextMessage {
                     role: "assistant".to_string(),
                     content: "Hello how can I help?".to_string(),
+                    ..Default::default()
                 },
                 TextMessage {
                     role: "user".to_string(),
                     content: "What is Deep Learning?".to_string(),
+                    ..Default::default()
                 },
                 TextMessage {
                     role: "assistant".to_string(),
                     content: "magic!".to_string(),
+                    ..Default::default()
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -186,6 +195,182 @@
         );
     }
 
+    #[test]
+    fn test_chat_template_with_tool_response() {
+        let env = Environment::new();
+
+        // template modified from Llama-3.1-8B-Instruct
+        // https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/0e9e39f249a16976918f6564b8830bc894c89659/tokenizer_config.json#L2053
+        // the main change is accessing `message.tool_call_id` from the messages
+        let source = r#"
+        {{- bos_token }}
+        {%- if custom_tools is defined %}
+        {%- set tools = custom_tools %}
+        {%- endif %}
+        {%- if not tools_in_user_message is defined %}
+        {%- set tools_in_user_message = true %}
+        {%- endif %}
+        {%- if not date_string is defined %}
+        {%- set date_string = "26 Jul 2024" %}
+        {%- endif %}
+        {%- if not tools is defined %}
+        {%- set tools = none %}
+        {%- endif %}
+
+        {#- This block extracts the system message, so we can slot it into the right place. #}
+        {%- if messages[0]['role'] == 'system' %}
+        {%- set system_message = messages[0]['content']|trim %}
+        {%- set messages = messages[1:] %}
+        {%- else %}
+        {%- set system_message = "" %}
+        {%- endif %}
+
+        {#- System message + builtin tools #}
+        {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
+        {%- if builtin_tools is defined or tools is not none %}
+        {{- "Environment: ipython\n" }}
+        {%- endif %}
+        {%- if builtin_tools is defined %}
+        {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
+        {%- endif %}
+        {{- "Cutting Knowledge Date: December 2023\n" }}
+        {{- "Today Date: " + date_string + "\n\n" }}
+        {%- if tools is not none and not tools_in_user_message %}
+        {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
+        {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
+        {{- "Do not use variables.\n\n" }}
+        {%- for t in tools %}
+        {{- t | tojson(indent=4) }}
+        {{- "\n\n" }}
+        {%- endfor %}
+        {%- endif %}
+        {{- system_message }}
+        {{- "<|eot_id|>" }}
+
+        {#- Custom tools are passed in a user message with some extra guidance #}
+        {%- if tools_in_user_message and not tools is none %}
+        {#- Extract the first user message so we can plug it in here #}
+        {%- if messages | length != 0 %}
+        {%- set first_user_message = messages[0]['content']|trim %}
+        {%- set messages = messages[1:] %}
+        {%- else %}
+        {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
+        {%- endif %}
+        {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
+        {{- "Given the following functions, please respond with a JSON for a function call " }}
+        {{- "with its proper arguments that best answers the given prompt.\n\n" }}
+        {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
+        {{- "Do not use variables.\n\n" }}
+        {%- for t in tools %}
+        {{- t | tojson(indent=4) }}
+        {{- "\n\n" }}
+        {%- endfor %}
+        {{- first_user_message + "<|eot_id|>"}}
+        {%- endif %}
+
+        {%- for message in messages %}
+        {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
+        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
+        {%- elif 'tool_calls' in message %}
+        {%- if not message.tool_calls|length == 1 %}
+        {{- raise_exception("This model only supports single tool-calls at once!") }}
+        {%- endif %}
+        {%- set tool_call = message.tool_calls[0].function %}
+        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
+        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+        {{- "<|python_tag|>" + tool_call.name + ".call(" }}
+        {%- for arg_name, arg_val in tool_call.arguments | items %}
+        {{- arg_name + '="' + arg_val + '"' }}
+        {%- if not loop.last %}
+        {{- ", " }}
+        {%- endif %}
+        {%- endfor %}
+        {{- ")" }}
+        {%- else %}
+        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+        {{- '{"name": "' + tool_call.name + '", ' }}
+        {{- '"parameters": ' }}
+        {{- tool_call.arguments | tojson }}
+        {{- "}" }}
+        {%- endif %}
+        {%- if builtin_tools is defined %}
+        {#- This means we're in ipython mode #}
+        {{- "<|eom_id|>" }}
+        {%- else %}
+        {{- "<|eot_id|>" }}
+        {%- endif %}
+        {%- elif message.role == "tool" or message.role == "ipython" %}
+        {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
+        {{- "TOOL CALL ID: " + message.tool_call_id + "\n\n" }}
+        {%- if message.content is mapping or message.content is iterable %}
+        {{- message.content | tojson }}
+        {%- else %}
+        {{- message.content }}
+        {%- endif %}
+        {{- "<|eot_id|>" }}
+        {%- endif %}
+        {%- endfor %}
+        {%- if add_generation_prompt %}
+        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+        {%- endif %}
+        "#;
+
+        // trim all the whitespace
+        let source = source
+            .lines()
+            .map(|line| line.trim())
+            .collect::<Vec<&str>>()
+            .join("");
+
+        let tmpl = env.template_from_str(&source);
+
+        let chat_template_inputs = ChatTemplateInputs {
+            messages: vec![
+                TextMessage {
+                    role: "user".to_string(),
+                    content: "Hi!".to_string(),
+                    ..Default::default()
+                },
+                TextMessage {
+                    role: "assistant".to_string(),
+                    content: r#"[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]"#.to_string(),
+                    ..Default::default()
+                },
+                TextMessage {
+                    role: "tool".to_string(),
+                    content: "6.7".to_string(),
+                    tool_call_id: Some("0".to_string()),
+                },
+            ],
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
+            add_generation_prompt: true,
+            ..Default::default()
+        };
+
+        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
+
+        assert_eq!(
+            result,
+            r#"[BOS]<|start_header_id|>system<|end_header_id|>
+
+Cutting Knowledge Date: December 2023
+Today Date: 26 Jul 2024
+
+<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
+
+TOOL CALL ID: 0
+
+"6.7"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+"#
+        );
+    }
+
     #[test]
     fn test_chat_template_loop_controls() {
         // some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
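For tool and ipython turns, the modified template renders "TOOL CALL ID: " followed by `message.tool_call_id`, which is why `TextMessage` grows that field. A stripped-down render of just that lookup, using minijinja directly (a standalone sketch, not part of the diff):

```rust
use minijinja::Environment;

fn main() {
    let mut env = Environment::new();
    env.add_template(
        "tool",
        "{% for m in messages %}{% if m.role == 'tool' %}TOOL CALL ID: {{ m.tool_call_id }}\n{{ m.content }}{% endif %}{% endfor %}",
    )
    .unwrap();

    let rendered = env
        .get_template("tool")
        .unwrap()
        .render(serde_json::json!({
            "messages": [
                { "role": "tool", "tool_call_id": "0", "content": "6.7" }
            ]
        }))
        .unwrap();

    assert_eq!(rendered, "TOOL CALL ID: 0\n6.7");
}
```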
"description": None, }, "type": "function", } ]"#.to_string(), + ..Default::default() + }, + TextMessage { + role: "tool".to_string(), + content: "6.7".to_string(), + tool_call_id: Some("0".to_string()), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + + assert_eq!( + result, + r#"[BOS]<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 +Today Date: 26 Jul 2024 + +<|eot_id|><|start_header_id|>user<|end_header_id|> + +Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|> + +TOOL CALL ID: 0 + +"6.7"<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +"# + ); + } + #[test] fn test_chat_template_loop_controls() { // some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break` @@ -224,18 +409,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -287,22 +476,27 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "Hi again!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -359,18 +553,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -426,18 +624,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -479,18 +681,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: 
"assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -516,14 +722,17 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hello, how are you?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "I'm doing great. How can I help you today?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "I'd like to show off how chat templating works!".to_string(), + ..Default::default() }, ]; @@ -531,6 +740,7 @@ mod tests { role: "system".to_string(), content: "You are a friendly chatbot who always responds in the style of a pirate" .to_string(), + ..Default::default() }] .iter() .chain(&example_chat) @@ -674,10 +884,12 @@ mod tests { TextMessage { role: "system".to_string(), content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "How many helicopters can a human eat in one sitting?".to_string(), + ..Default::default() }, ], add_generation_prompt: true, @@ -949,19 +1161,27 @@ mod tests { Message { name: None, role: "user".to_string(), - content: MessageContent::SingleText( - "I'd like to show off how chat templating works!".to_string(), - ), + body: MessageBody::Content { + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, }, Message { name: None, role: "assistant".to_string(), - content: MessageContent::SingleText("Great! How can I help you today?".to_string()), + body: MessageBody::Content { + content: MessageContent::SingleText( + "Great! How can I help you today?".to_string(), + ), + }, }, Message { name: None, role: "user".to_string(), - content: MessageContent::SingleText("Just testing".to_string()), + body: MessageBody::Content { + content: MessageContent::SingleText("Just testing".to_string()), + }, }, ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); @@ -985,17 +1205,21 @@ mod tests { Message { name: None, role: "system".to_string(), - content: MessageContent::SingleText( - "Youre a helpful assistant! Answer the users question best you can." - .to_string(), - ), + body: MessageBody::Content { + content: MessageContent::SingleText( + "Youre a helpful assistant! Answer the users question best you can." + .to_string(), + ), + }, }, Message { name: None, role: "user".to_string(), - content: MessageContent::SingleText( - "What is the weather like in Brooklyn, New York?".to_string(), - ), + body: MessageBody::Content { + content: MessageContent::SingleText( + "What is the weather like in Brooklyn, New York?".to_string(), + ), + }, }, ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 414d38ed..e8c875a8 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -663,6 +663,7 @@ impl ChatCompletion {
             (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
                 role: "assistant".into(),
                 content,
+                ..Default::default()
             }),
             (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
                 role: "assistant".to_string(),
@@ -673,6 +674,7 @@ impl ChatCompletion {
                 OutputMessage::ChatMessage(TextMessage {
                     role: "assistant".into(),
                     content: output,
+                    ..Default::default()
                 })
             }
             (None, None) => {
@@ -680,6 +682,7 @@
                 OutputMessage::ChatMessage(TextMessage {
                     role: "assistant".into(),
                     content: "".to_string(),
+                    ..Default::default()
                 })
             }
         };
@@ -767,6 +770,7 @@ impl ChatCompletionChunk {
             (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
                 role: "assistant".to_string(),
                 content: delta,
+                ..Default::default()
             }),
             (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
                 role: "assistant".to_string(),
@@ -783,6 +787,7 @@
             (None, None) => ChatCompletionDelta::Chat(TextMessage {
                 role: "assistant".to_string(),
                 content: "".to_string(),
+                ..Default::default()
             }),
         };
         Self {
@@ -1025,7 +1030,7 @@ pub fn default_tool_prompt() -> String {
     "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }
 
-#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema, PartialEq, Serialize)]
 #[serde(tag = "type")]
 pub enum TypedChoice {
     #[serde(rename = "function")]
@@ -1100,19 +1105,19 @@ pub struct JsonSchemaTool {
     properties: Properties,
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
 struct FunctionsMap {
     #[serde(rename = "$functions")]
     functions: std::collections::HashMap<String, FunctionRef>,
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
 struct FunctionRef {
     #[serde(rename = "$ref")]
     ref_path: String,
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]
 struct Properties {
     #[serde(serialize_with = "serialize_function")]
     function: Vec<FunctionRef>,
 }
@@ -1129,7 +1134,7 @@ where
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
-pub(crate) struct FunctionDefinition {
+pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
@@ -1157,7 +1162,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
-pub(crate) struct ToolCall {
+pub struct ToolCall {
     pub id: String,
     pub r#type: String,
     pub function: FunctionDefinition,
@@ -1176,15 +1181,31 @@ pub enum MessageChunk {
     ImageUrl { image_url: Url },
 }
 
-#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
 pub struct Message {
     #[schema(example = "user")]
-    role: String,
+    pub role: String,
+    #[serde(flatten)]
     #[schema(example = "My name is David and I")]
-    pub content: MessageContent,
+    pub body: MessageBody,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
-    name: Option<String>,
+    pub name: Option<String>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
+#[serde(untagged)]
+pub enum MessageBody {
+    // When a regular text message is provided.
+    Content {
+        #[serde(rename = "content")]
+        content: MessageContent,
+    },
+    // When tool calls are provided.
+    Tool {
+        #[serde(rename = "tool_calls")]
+        tool_calls: Vec<ToolCall>,
+    },
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
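The `#[serde(flatten)]` on `body` plus `#[serde(untagged)]` on `MessageBody` is what lets `content` and `tool_calls` appear as mutually exclusive siblings of `role` in the JSON, with no discriminator field. A self-contained mirror of that layout, simplified to `String` payloads so it runs on its own (the real types use `MessageContent` and `Vec<ToolCall>`):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    // Flattened: the enum's fields are read from the same JSON object as `role`.
    #[serde(flatten)]
    body: MessageBody,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MessageBody {
    // Tried first: matches objects that provide `content`.
    Content { content: String },
    // Fallback: matches objects that provide `tool_calls`.
    Tool { tool_calls: Vec<String> },
}

fn main() {
    let text: Message =
        serde_json::from_str(r#"{"role": "user", "content": "Hi!"}"#).unwrap();
    let tool: Message =
        serde_json::from_str(r#"{"role": "assistant", "tool_calls": ["call-0"]}"#).unwrap();

    assert!(matches!(text.body, MessageBody::Content { .. }));
    assert!(matches!(tool.body, MessageBody::Tool { .. }));
}
```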
"\"David\"")] - name: Option, + pub name: Option, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageBody { + // When a regular text message is provided. + Content { + #[serde(rename = "content")] + content: MessageContent, + }, + // When tool calls are provided. + Tool { + #[serde(rename = "tool_calls")] + tool_calls: Vec, + }, } #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] @@ -1211,19 +1232,28 @@ impl MessageContent { } } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)] pub struct TextMessage { #[schema(example = "user")] pub role: String, #[schema(example = "My name is David and I")] pub content: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } impl From for TextMessage { fn from(value: Message) -> Self { + let content = match value.body { + MessageBody::Content { content } => content, + MessageBody::Tool { tool_calls } => { + let content = serde_json::to_string(&tool_calls).unwrap_or_default(); + MessageContent::SingleText(content) + } + }; TextMessage { role: value.role, - content: match value.content { + content: match content { MessageContent::SingleText(text) => text, MessageContent::MultipleChunks(chunks) => chunks .into_iter() @@ -1234,6 +1264,7 @@ impl From for TextMessage { .collect::>() .join(""), }, + ..Default::default() } } } @@ -1565,9 +1596,11 @@ mod tests { assert_eq!( request.messages[0], Message { + name: None, role: "user".to_string(), - content: MessageContent::SingleText("What is Deep Learning?".to_string()), - name: None + body: MessageBody::Content { + content: MessageContent::SingleText("What is Deep Learning?".to_string()) + }, } ); } @@ -1617,13 +1650,16 @@ mod tests { assert_eq!( request.messages[0], - Message{ + Message { + name: None, role: "user".to_string(), - content: MessageContent::MultipleChunks(vec![ - MessageChunk::Text { text: "Whats in this image?".to_string() }, - MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, - ]), - name: None + + body: MessageBody::Content { + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), + }, } ); } @@ -1631,12 +1667,14 @@ mod tests { #[test] fn text_message_convert() { let message = Message{ + name: None, role: "user".to_string(), - content: MessageContent::MultipleChunks(vec![ - MessageChunk::Text { text: "Whats in this image?".to_string() }, - MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } - ]), - name: None + body: MessageBody::Content { + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), + } }; let textmsg: TextMessage = message.into(); assert_eq!(textmsg.content, "Whats in this 
diff --git a/router/src/server.rs b/router/src/server.rs
index e9d2fcf4..e9aa4612 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -28,7 +28,7 @@ use crate::{
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
 };
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
-use crate::{ModelInfo, ModelsInfo};
+use crate::{MessageBody, ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
 use axum::extract::{DefaultBodyLimit, Extension};
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -1577,6 +1577,7 @@
 FunctionDefinition,
 ToolChoice,
 ModelInfo,
 ChatTokenizeResponse,
+MessageBody,
 )
 ),
 tags(
diff --git a/router/src/vertex.rs b/router/src/vertex.rs
index 0a8c2278..38695532 100644
--- a/router/src/vertex.rs
+++ b/router/src/vertex.rs
@@ -147,7 +147,7 @@ pub(crate) async fn vertex_compatibility(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{Message, MessageContent};
+    use crate::{Message, MessageBody, MessageContent};
 
     #[test]
     fn vertex_deserialization() {
@@ -169,9 +169,13 @@ mod tests {
         VertexRequest {
             instances: vec![VertexInstance::Chat(ChatRequest {
                 messages: vec![Message {
-                    role: "user".to_string(),
-                    content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                     name: None,
+                    role: "user".to_string(),
+                    body: MessageBody::Content {
+                        content: MessageContent::SingleText(
+                            "What's Deep Learning?".to_string()
+                        )
+                    },
                 },],
                 max_tokens: Some(128),
                 top_p: Some(0.95),