diff --git a/router/src/infer.rs b/router/src/infer.rs index 85e8775e..9646deb9 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -990,24 +990,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1060,30 +1064,35 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("Hi again!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1141,24 +1150,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1200,24 +1213,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1245,18 +1262,21 @@ mod tests { content: Some("Hello, how are you?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("I'm doing great. How can I help you today?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("I'd like to show off how chat templating works!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ]; @@ -1268,6 +1288,7 @@ mod tests { ), name: None, tool_calls: None, + tool_call_id: None, }] .iter() .chain(&example_chat) @@ -1412,12 +1433,14 @@ mod tests { content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("How many helicopters can a human eat in one sitting?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], add_generation_prompt: true, diff --git a/router/src/lib.rs b/router/src/lib.rs index 96a9fdf6..85e18dfb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -546,6 +546,7 @@ impl ChatCompletion { content: output, name: None, tool_calls, + tool_call_id: None, }, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), @@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> { #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] pub(crate) struct ToolCall { - pub id: u32, + pub id: String, pub r#type: String, pub function: FunctionDefinition, } @@ -954,13 +955,16 @@ pub(crate) struct Message { pub role: String, #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] + #[serde(default, deserialize_with = "message_content_serde::deserialize")] pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] pub name: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[schema(example = "\"get_weather\"")] + pub tool_call_id: Option, } #[derive(Clone, Debug, Deserialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 6b51109b..52652b72 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -990,7 +990,6 @@ async fn chat_completions( ) -> Result)> { let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); - let ChatRequest { logprobs, max_tokens, @@ -1162,7 +1161,7 @@ async fn chat_completions( ) })?; let tool_calls = vec![ToolCall { - id: 0, + id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None,