commit 34a3d09bb8
drbh, 2025-04-08 16:25:23 +08:00, committed by GitHub
7 changed files with 84 additions and 29 deletions

View File

@@ -1520,6 +1520,10 @@
           "type": "string",
           "nullable": true
         },
+        "id": {
+          "type": "string",
+          "nullable": true
+        },
         "name": {
           "type": "string"
         }
@@ -1883,6 +1887,11 @@
         "role": {
           "type": "string",
           "example": "user"
+        },
+        "tool_call_id": {
+          "type": "string",
+          "example": "10",
+          "nullable": true
         }
       }
     }
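Taken together, the two schema additions let a client correlate a tool result message with the call that produced it. Below is a hypothetical request payload exercising the new nullable fields; the `tool` role value and the `content` string are assumptions for illustration, not part of this diff:

```rust
use serde_json::json;

fn main() {
    // A follow-up message referencing an earlier tool call by id.
    // `tool_call_id` is nullable in the schema, so ordinary messages omit it.
    let message = json!({
        "role": "tool",                 // assumed role value, not shown in this diff
        "tool_call_id": "10",           // matches the schema's example above
        "content": "72F and sunny"      // illustrative tool output
    });
    println!("{message}");
}
```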

View File

@@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
         id: "0".to_string(),
         r#type: "function".to_string(),
         function: FunctionDefinition {
+            id: None,
             description: None,
             name: name.to_string(),
             arguments: serde_json::to_value(call.function.arguments).map_err(|err| {

View File

@@ -96,15 +96,21 @@ impl ChatTemplate {
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
         let final_message = messages.last().cloned();
-        let mut rendered_template = self
-            .template
-            .render(ChatTemplateInputs {
+        let template_inputs = ChatTemplateInputs {
             messages,
             bos_token: self.bos_token.as_deref(),
             eos_token: self.eos_token.as_deref(),
             add_generation_prompt: true,
             tools,
-            })
+        };
+        // NOTE: initializing `template_inputs` is helpful when JSON dumping the
+        // `ChatTemplateInputs` struct for debugging
+        // let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
+        let mut rendered_template = self
+            .template
+            .render(template_inputs)
             .map_err(InferError::TemplateError)?;

         // if the last message is from the assistant, continue the generation prompt
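Binding the inputs to a name before rendering is what enables the commented-out debug dump. A minimal sketch of that debugging pattern, using a trimmed stand-in for `ChatTemplateInputs` (the real struct also carries bos/eos tokens and tools):

```rust
use serde::Serialize;

// Trimmed stand-in for the router's `ChatTemplateInputs`.
#[derive(Debug, Serialize)]
struct ChatTemplateInputs<'a> {
    messages: Vec<&'a str>,
    add_generation_prompt: bool,
}

fn main() {
    let template_inputs = ChatTemplateInputs {
        messages: vec!["hello"],
        add_generation_prompt: true,
    };
    // The diff's commented-out line, made live: dump the inputs as JSON
    // to inspect exactly what the template will receive.
    let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
    println!("{template_inputs_as_json}");
}
```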
@@ -1175,6 +1181,7 @@ TOOL CALL ID: 0
                     "I'd like to show off how chat templating works!".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1184,6 +1191,7 @@ TOOL CALL ID: 0
                     "Great! How can I help you today?".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1191,6 +1199,7 @@ TOOL CALL ID: 0
             body: MessageBody::Content {
                 content: MessageContent::SingleText("Just testing".to_string()),
             },
+            tool_call_id: None,
         },
     ];
     let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1220,6 +1229,7 @@ TOOL CALL ID: 0
                     .to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1229,6 +1239,7 @@ TOOL CALL ID: 0
                     "What is the weather like in Brooklyn, New York?".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
     ];
     let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1299,6 +1310,7 @@ TOOL CALL ID: 0
                     text: "You are a helpful assistant.".to_string(),
                 }]),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1326,6 +1338,7 @@ TOOL CALL ID: 0
                 },
             ]),
         },
+        tool_call_id: None,
     },
 ];

View File

@@ -34,6 +34,7 @@ impl ToolGrammar {
             .chain(std::iter::once(Tool {
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
+                    id: None,
                     name: "no_tool".to_string(),
                     description: Some(
                         "Open ended response with no specific tool selected".to_string(),

View File

@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -919,7 +920,10 @@ pub(crate) struct ChatRequest {
 }

 impl ChatRequest {
-    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
+    fn try_into_generate(
+        self,
+        infer: &Infer,
+    ) -> Result<(GenerateRequest, Option<HashMap<String, String>>), InferError> {
         let ChatRequest {
             model,
             max_tokens,
@@ -959,7 +963,7 @@ impl ChatRequest {
         let (inputs, grammar, using_tools) = match response_format {
             Some(format) => {
                 let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, Some(format), false)
+                (inputs, Some(format), None)
             }
             None => {
                 if let Some(tools) = tools {
@@ -968,20 +972,31 @@ impl ChatRequest {
                         let grammar = GrammarType::Json(serde_json::json!(tool_schema));
                         let inputs: String = infer.apply_chat_template(
                             messages,
-                            Some((updated_tools, tool_prompt)),
+                            Some((updated_tools.clone(), tool_prompt)),
                         )?;
-                        (inputs, Some(grammar), true)
+                        let tool_name_to_id: HashMap<String, String> = updated_tools
+                            .into_iter()
+                            .map(|tool| {
+                                (
+                                    tool.function.name,
+                                    tool.function
+                                        .id
+                                        .map_or_else(|| "0".to_string(), |id| id.to_string()),
+                                )
+                            })
+                            .collect();
+                        (inputs, Some(grammar), Some(tool_name_to_id))
                     }
                     None => {
                         // same as if no response_format or tools are set
                         let inputs = infer.apply_chat_template(messages, None)?;
-                        (inputs, None, false)
+                        (inputs, None, None)
                     }
                 }
             } else {
                 // if no response_format or tools are set simply apply the chat template to generate inputs
                 let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, None, false)
+                (inputs, None, None)
             }
         };
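The old `bool` flag becomes an optional name-to-id map here, with `"0"` as the default when a tool carries no id. A self-contained sketch of just that mapping, using trimmed stand-ins rather than the router's real `Tool`/`FunctionDefinition`:

```rust
use std::collections::HashMap;

// Trimmed stand-ins for the router's `Tool` and `FunctionDefinition`.
struct FunctionDefinition {
    name: String,
    id: Option<String>,
}
struct Tool {
    function: FunctionDefinition,
}

// Mirrors the map built in `try_into_generate`: tool name -> caller-supplied
// id, falling back to "0" when no id was provided.
fn tool_ids(tools: Vec<Tool>) -> HashMap<String, String> {
    tools
        .into_iter()
        .map(|tool| {
            (
                tool.function.name,
                tool.function.id.unwrap_or_else(|| "0".to_string()),
            )
        })
        .collect()
}

fn main() {
    let map = tool_ids(vec![Tool {
        function: FunctionDefinition {
            name: "get_current_weather".to_string(),
            id: Some("10".to_string()),
        },
    }]);
    assert_eq!(map["get_current_weather"], "10");
}
```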
@@ -1161,6 +1176,8 @@ pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
     #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
     pub arguments: serde_json::Value,
 }
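The two serde attributes on `id` keep the wire format backward compatible: the field may be absent on input and is omitted, rather than emitted as `null`, on output. A minimal sketch isolating that behavior:

```rust
use serde::{Deserialize, Serialize};

// Isolates the serde behavior added above: `id` is optional when parsing
// (`default`) and dropped from output when unset (`skip_serializing_if`).
#[derive(Serialize, Deserialize)]
struct FunctionDefinition {
    name: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    id: Option<String>,
}

fn main() {
    let f = FunctionDefinition { name: "myfn".to_string(), id: None };
    // No `"id": null` in the output; the key disappears entirely.
    assert_eq!(serde_json::to_string(&f).unwrap(), r#"{"name":"myfn"}"#);

    // Payloads without `id` still deserialize cleanly.
    let parsed: FunctionDefinition = serde_json::from_str(r#"{"name":"myfn"}"#).unwrap();
    assert!(parsed.id.is_none());
}
```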
@@ -1182,7 +1199,7 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }

-#[derive(Clone, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<TextMessage>,
     bos_token: Option<&'a str>,
@@ -1215,6 +1232,9 @@ pub enum MessageChunk {
 pub struct Message {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = "10")]
+    pub tool_call_id: Option<String>,
     #[serde(flatten)]
     #[schema(example = "My name is David and I")]
     pub body: MessageBody,
@@ -1294,7 +1314,7 @@ impl From<Message> for TextMessage {
                 .collect::<Vec<_>>()
                 .join(""),
         },
-        ..Default::default()
+        tool_call_id: value.tool_call_id,
     }
 }
@@ -1631,6 +1651,7 @@ mod tests {
             body: MessageBody::Content {
                 content: MessageContent::SingleText("What is Deep Learning?".to_string())
             },
+            tool_call_id: None,
         }
     );
 }
@@ -1690,6 +1711,7 @@ mod tests {
                 MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
             ]),
         },
+        tool_call_id: None,
     }
 );
@@ -1704,7 +1726,8 @@ mod tests {
             MessageChunk::Text { text: "Whats in this image?".to_string() },
             MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
         ]),
-    }
+    },
+    tool_call_id: None
 };
 let textmsg: TextMessage = message.into();
 assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
@@ -1765,6 +1788,7 @@ mod tests {
     id: "0".to_string(),
     r#type: "function".to_string(),
     function: FunctionDefinition {
+        id: None,
         description: None,
         name: "myfn".to_string(),
         arguments: json!({
View File

@@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions(
     tracing::debug!("Got chat_template {:?}", infer.chat_template);
     let id = chat.next_tool_call_id();
-    let (generate_request, using_tools): (GenerateRequest, bool) =
-        chat.clone().try_into_generate(&infer)?;
+    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
     let logprobs = logprobs.unwrap_or_default();
@@ -1188,7 +1187,7 @@ pub(crate) async fn chat_completions(
     let response_stream = async_stream::stream! {
         let mut response_stream = Box::pin(response_stream);
-        let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+        let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
         while let Some(result) = response_stream.next().await {
             match result{
                 Ok(stream_token) => {
@@ -1197,12 +1196,12 @@ pub(crate) async fn chat_completions(
                     ChatEvent::NoTool => {
                         chat.tools = None;
                         chat.response_format = None;
-                        let (generate_request, using_tools): (GenerateRequest, bool) =
+                        let (generate_request, using_tools) =
                             chat.clone().try_into_generate(&infer).unwrap();
-                        assert!(!using_tools);
+                        assert!(using_tools.is_none());
                         let (_headers, response_stream2) =
                             generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
-                        state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+                        state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
                         response_stream = Box::pin(response_stream2);
                     }
                     ChatEvent::Events(events) => {
@@ -1237,14 +1236,13 @@ pub(crate) async fn chat_completions(
         .unwrap_or_else(|_| std::time::Duration::from_secs(0))
         .as_secs();

-    let (tool_calls, output) = if using_tools {
+    let (tool_calls, output) = if using_tools.is_some() {
         match crate::chat::parse_output(&generation.generated_text)? {
             ChatChoice::NoTool => {
                 chat.tools = None;
                 chat.response_format = None;
-                let (generate_request, using_tools): (GenerateRequest, bool) =
-                    chat.clone().try_into_generate(&infer)?;
-                assert!(!using_tools);
+                let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
+                assert!(using_tools.is_none());
                 let (headers_final, input_length_final, Json(generation)) = generate_internal(
                     Extension(infer),
                     compute_type,
@@ -1256,7 +1254,16 @@ pub(crate) async fn chat_completions(
                 input_length = input_length_final;
                 (None, Some(generation.generated_text))
             }
-            ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
+            ChatChoice::ToolCalls(mut tool_calls) => {
+                // assign the tool ids based on the tool names
+                tool_calls.iter_mut().for_each(|tool_call| {
+                    tool_call.id = using_tools
+                        .as_ref()
+                        .and_then(|tools| tools.get(&tool_call.function.name))
+                        .map_or("0".to_string(), |id| id.clone());
+                });
+                (Some(tool_calls), None)
+            }
         }
     } else {
         (None, Some(generation.generated_text))
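With the map threaded through, the `ToolCalls` arm can now stamp each parsed call with its caller-supplied id, falling back to `"0"` for names the map does not know. A stand-alone sketch of that rewrite step, with the types trimmed to essentials:

```rust
use std::collections::HashMap;

// Trimmed stand-ins for the router's tool-call types.
struct Function { name: String }
struct ToolCall { id: String, function: Function }

// Mirrors the patch step above: each parsed call gets the id registered for
// its function name, or "0" when the name (or the whole map) is missing.
fn assign_ids(calls: &mut [ToolCall], ids: Option<&HashMap<String, String>>) {
    calls.iter_mut().for_each(|call| {
        call.id = ids
            .and_then(|m| m.get(&call.function.name))
            .map_or("0".to_string(), |id| id.clone());
    });
}

fn main() {
    let ids = HashMap::from([("get_current_weather".to_string(), "10".to_string())]);
    let mut calls = vec![ToolCall {
        id: String::new(),
        function: Function { name: "get_current_weather".to_string() },
    }];
    assign_ids(&mut calls, Some(&ids));
    assert_eq!(calls[0].id, "10");
}
```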

View File

@@ -104,8 +104,7 @@ pub(crate) async fn vertex_compatibility(
         },
     },
     VertexInstance::Chat(instance) => {
-        let (generate_request, _using_tools): (GenerateRequest, bool) =
-            instance.try_into_generate(&infer)?;
+        let (generate_request, _using_tools) = instance.try_into_generate(&infer)?;
         generate_request
     }
 };
@@ -176,6 +175,7 @@ mod tests {
             "What's Deep Learning?".to_string()
         )
     },
+    tool_call_id: None,
 },],
 max_tokens: Some(128),
 top_p: Some(0.95),