From f691a945aac1eb3f1b339366e1c943e68438893a Mon Sep 17 00:00:00 2001
From: phangiabao98 <60313144+phangiabao98@users.noreply.github.com>
Date: Thu, 16 May 2024 15:17:00 +0700
Subject: [PATCH] OpenAI function calling compatible support (#1888)

# What does this PR do?

Fixes https://github.com/huggingface/text-generation-inference/issues/1887

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [x] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [x] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@Narsil

---------

Co-authored-by: Bao Phan
---
 router/src/infer.rs  | 23 +++++++++++++++++++++++
 router/src/lib.rs    |  8 ++++++--
 router/src/server.rs |  3 +--
 3 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 85e8775e..9646deb9 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -990,24 +990,28 @@ mod tests {
                     content: Some("Hi!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("Hello how can I help?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "user".to_string(),
                     content: Some("What is Deep Learning?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("magic!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1060,30 +1064,35 @@ mod tests {
                     content: Some("Hi!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "user".to_string(),
                     content: Some("Hi again!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("Hello how can I help?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "user".to_string(),
                     content: Some("What is Deep Learning?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("magic!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1141,24 +1150,28 @@ mod tests {
                     content: Some("Hi!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("Hello how can I help?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "user".to_string(),
                     content: Some("What is Deep Learning?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("magic!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1200,24 +1213,28 @@ mod tests {
                     content: Some("Hi!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("Hello how can I help?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "user".to_string(),
                     content: Some("What is Deep Learning?".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
                 Message {
                     role: "assistant".to_string(),
                     content: Some("magic!".to_string()),
                     name: None,
                     tool_calls: None,
+                    tool_call_id: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1245,18 +1262,21 @@ mod tests {
                 content: Some("Hello, how are you?".to_string()),
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
             Message {
                 role: "assistant".to_string(),
                 content: Some("I'm doing great. How can I help you today?".to_string()),
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
             Message {
                 role: "user".to_string(),
                 content: Some("I'd like to show off how chat templating works!".to_string()),
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
         ];
 
@@ -1268,6 +1288,7 @@ mod tests {
             ),
             name: None,
             tool_calls: None,
+            tool_call_id: None,
         }]
         .iter()
         .chain(&example_chat)
@@ -1412,12 +1433,14 @@ mod tests {
                 content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
             Message {
                 role: "user".to_string(),
                 content: Some("How many helicopters can a human eat in one sitting?".to_string()),
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
         ],
         add_generation_prompt: true,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 96a9fdf6..85e18dfb 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -546,6 +546,7 @@ impl ChatCompletion {
                 content: output,
                 name: None,
                 tool_calls,
+                tool_call_id: None,
             },
             logprobs: return_logprobs
                 .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
 pub(crate) struct ToolCall {
-    pub id: u32,
+    pub id: String,
     pub r#type: String,
     pub function: FunctionDefinition,
 }
@@ -954,13 +955,16 @@ pub(crate) struct Message {
     pub role: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     #[schema(example = "My name is David and I")]
-    #[serde(deserialize_with = "message_content_serde::deserialize")]
+    #[serde(default, deserialize_with = "message_content_serde::deserialize")]
     pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub tool_calls: Option<Vec<ToolCall>>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = "\"get_weather\"")]
+    pub tool_call_id: Option<String>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 6b51109b..52652b72 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -990,7 +990,6 @@ async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
-
     let ChatRequest {
         logprobs,
         max_tokens,
@@ -1162,7 +1161,7 @@ async fn chat_completions(
         )
     })?;
     let tool_calls = vec![ToolCall {
-        id: 0,
+        id: "0".to_string(),
         r#type: "function".to_string(),
         function: FunctionDefinition {
             description: None,
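
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of why the two type changes matter for OpenAI compatibility. It uses simplified stand-ins for the `ToolCall` and `Message` types from `router/src/lib.rs` (field subset reduced, `FunctionDefinition` trimmed), and assumes `serde` with the `derive` feature plus `serde_json` as dependencies; the `call_abc123` id and `get_weather` payload are illustrative, not from the codebase.

```rust
// Minimal sketch, NOT part of the patch: simplified stand-ins for the
// patched types, showing the two OpenAI-compatibility fixes.
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
struct FunctionDefinition {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
    // Change 1: was `u32`; OpenAI clients send string ids such as
    // "call_abc123", which a numeric field fails to deserialize.
    id: String,
    r#type: String,
    function: FunctionDefinition,
}

#[derive(Debug, Deserialize, Serialize)]
struct Message {
    role: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    // Change 2: new optional field so a `tool` role message can reference
    // the tool call it answers, as in OpenAI's chat completions API.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

fn main() {
    // An assistant turn requesting a tool call (string id), followed by
    // the tool result echoing that id back via `tool_call_id`.
    let turns = r#"[
        {"role": "assistant",
         "tool_calls": [{"id": "call_abc123", "type": "function",
                         "function": {"name": "get_weather",
                                      "arguments": "{\"city\": \"Paris\"}"}}]},
        {"role": "tool",
         "content": "{\"temperature_c\": 21}",
         "tool_call_id": "call_abc123"}
    ]"#;
    let msgs: Vec<Message> = serde_json::from_str(turns).expect("valid turns");
    assert_eq!(msgs[0].tool_calls.as_ref().unwrap()[0].id, "call_abc123");
    assert_eq!(msgs[1].tool_call_id.as_deref(), Some("call_abc123"));
    println!("{}", serde_json::to_string_pretty(&msgs).unwrap());
}
```

Before this patch, both turns above would be rejected: the string tool-call id could not deserialize into `id: u32`, and `Message` had no `tool_call_id` field at all, so round-tripping an OpenAI-style tool conversation through the router failed.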