mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00

Merge e721574729 into 0b28aabb94
This commit is contained in: commit 34a3d09bb8

The diff below adds an optional `tool_call_id` field to chat messages and an optional `id` to function definitions, threading the tool-call ids supplied in a request through chat-template rendering and back onto the tool calls parsed from the model output.
@@ -1520,6 +1520,10 @@
         "type": "string",
         "nullable": true
       },
+      "id": {
+        "type": "string",
+        "nullable": true
+      },
       "name": {
         "type": "string"
       }
@@ -1883,6 +1887,11 @@
       "role": {
         "type": "string",
         "example": "user"
+      },
+      "tool_call_id": {
+        "type": "string",
+        "example": "10",
+        "nullable": true
       }
     }
   }
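Taken together, the two schema hunks above add an optional, nullable `id` to the function-definition schema and an optional `tool_call_id` to the message schema. Below is a minimal standalone sketch of the resulting wire shape; the `*Sketch` structs are hypothetical stand-ins, not the crate's real types (those are changed further down in this diff).

    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct FunctionDefinitionSketch {
        name: String,
        // New: optional id, omitted from the JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
    }

    #[derive(Serialize, Deserialize)]
    struct MessageSketch {
        role: String,
        content: String,
        // New: links a tool-role message back to the call it answers.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        tool_call_id: Option<String>,
    }

    fn main() {
        let call = FunctionDefinitionSketch {
            name: "get_current_weather".to_string(),
            id: Some("10".to_string()),
        };
        let msg = MessageSketch {
            role: "tool".to_string(),
            content: "22 degrees and sunny".to_string(),
            tool_call_id: call.id.clone(),
        };
        // Prints {"role":"tool","content":"22 degrees and sunny","tool_call_id":"10"};
        // a `None` id would omit the field entirely, matching "nullable" in the schema.
        println!("{}", serde_json::to_string(&msg).unwrap());
    }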
@@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError
         id: "0".to_string(),
         r#type: "function".to_string(),
         function: FunctionDefinition {
+            id: None,
             description: None,
             name: name.to_string(),
             arguments: serde_json::to_value(call.function.arguments).map_err(|err| {
@@ -96,15 +96,21 @@ impl ChatTemplate {

         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
         let final_message = messages.last().cloned();
-        let mut rendered_template = self
-            .template
-            .render(ChatTemplateInputs {
+        let template_inputs = ChatTemplateInputs {
             messages,
             bos_token: self.bos_token.as_deref(),
             eos_token: self.eos_token.as_deref(),
             add_generation_prompt: true,
             tools,
-            })
+        };
+
+        // NOTE: initializing `template_inputs` is helpful when JSON dumping the
+        // `ChatTemplateInputs` struct for debugging
+        // let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
+
+        let mut rendered_template = self
+            .template
+            .render(template_inputs)
             .map_err(InferError::TemplateError)?;

         // if the last message is from the assistant, continue the generation prompt
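The refactor above does not change rendering behavior: it only binds the template context to a named `template_inputs` so it can be serialized while debugging. A sketch of the same pattern, assuming the minijinja crate (which the router's `.render(...)`-style templating builds on); the template string and messages here are made up for illustration, not the router's actual chat template.

    use minijinja::{context, Environment};

    fn main() {
        let mut env = Environment::new();
        env.add_template(
            "chat",
            "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
        )
        .unwrap();

        // Naming the inputs first makes them easy to dump before rendering, e.g.
        // println!("{}", serde_json::to_string(&template_inputs).unwrap());
        let template_inputs = context! {
            messages => vec![
                context! { role => "user", content => "What is Deep Learning?" },
            ],
        };

        let rendered = env
            .get_template("chat")
            .unwrap()
            .render(&template_inputs)
            .unwrap();
        println!("{rendered}");
    }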
@@ -1175,6 +1181,7 @@ TOOL CALL ID: 0
                     "I'd like to show off how chat templating works!".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1184,6 +1191,7 @@ TOOL CALL ID: 0
                     "Great! How can I help you today?".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1191,6 +1199,7 @@ TOOL CALL ID: 0
             body: MessageBody::Content {
                 content: MessageContent::SingleText("Just testing".to_string()),
             },
+            tool_call_id: None,
         },
     ];
     let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1220,6 +1229,7 @@ TOOL CALL ID: 0
                     .to_string(),
                 ),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1229,6 +1239,7 @@ TOOL CALL ID: 0
                     "What is the weather like in Brooklyn, New York?".to_string(),
                 ),
             },
+            tool_call_id: None,
         },
     ];
     let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1299,6 +1310,7 @@ TOOL CALL ID: 0
                     text: "You are a helpful assistant.".to_string(),
                 }]),
             },
+            tool_call_id: None,
         },
         Message {
             name: None,
@@ -1326,6 +1338,7 @@ TOOL CALL ID: 0
                     },
                 ]),
             },
+            tool_call_id: None,
         },
     ];

@@ -34,6 +34,7 @@ impl ToolGrammar {
             .chain(std::iter::once(Tool {
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
+                    id: None,
                     name: "no_tool".to_string(),
                     description: Some(
                         "Open ended response with no specific tool selected".to_string(),
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -919,7 +920,10 @@ pub(crate) struct ChatRequest {
 }

 impl ChatRequest {
-    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
+    fn try_into_generate(
+        self,
+        infer: &Infer,
+    ) -> Result<(GenerateRequest, Option<HashMap<String, String>>), InferError> {
         let ChatRequest {
             model,
             max_tokens,
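The signature change above replaces the old `bool` "using tools" flag with `Option<HashMap<String, String>>`: `Some(map)` both signals that tools are in play and carries the tool-name-to-id map built further down. A tiny sketch of how call sites recover the old boolean; the variable and map contents are illustrative.

    use std::collections::HashMap;

    fn main() {
        let using_tools: Option<HashMap<String, String>> = Some(HashMap::from([(
            "get_current_weather".to_string(),
            "10".to_string(),
        )]));

        // Call sites that only need the old flag use `.is_some()` ...
        assert!(using_tools.is_some());
        // ... while the id-assignment path can consult the map itself.
        assert_eq!(using_tools.as_ref().unwrap()["get_current_weather"], "10");
    }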
@@ -959,7 +963,7 @@ impl ChatRequest {
         let (inputs, grammar, using_tools) = match response_format {
             Some(format) => {
                 let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, Some(format), false)
+                (inputs, Some(format), None)
             }
             None => {
                 if let Some(tools) = tools {
@@ -968,20 +972,31 @@ impl ChatRequest {
                         let grammar = GrammarType::Json(serde_json::json!(tool_schema));
                         let inputs: String = infer.apply_chat_template(
                             messages,
-                            Some((updated_tools, tool_prompt)),
+                            Some((updated_tools.clone(), tool_prompt)),
                         )?;
-                        (inputs, Some(grammar), true)
+                        let tool_name_to_id: HashMap<String, String> = updated_tools
+                            .into_iter()
+                            .map(|tool| {
+                                (
+                                    tool.function.name,
+                                    tool.function
+                                        .id
+                                        .map_or_else(|| "0".to_string(), |id| id.to_string()),
+                                )
+                            })
+                            .collect();
+                        (inputs, Some(grammar), Some(tool_name_to_id))
                     }
                     None => {
                         // same as if no response_format or tools are set
                         let inputs = infer.apply_chat_template(messages, None)?;
-                        (inputs, None, false)
+                        (inputs, None, None)
                     }
                 }
             } else {
                 // if no response_format or tools are set simply apply the chat template to generate inputs
                 let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, None, false)
+                (inputs, None, None)
             }
         }
     };
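A standalone sketch of the name-to-id map built above. `ToolSketch` and `FunctionSketch` are hypothetical stand-ins for the crate's `Tool` and `FunctionDefinition`; the defaulting behavior is the point: tools without an explicit id map to "0".

    use std::collections::HashMap;

    struct FunctionSketch {
        name: String,
        id: Option<String>,
    }

    struct ToolSketch {
        function: FunctionSketch,
    }

    fn main() {
        let updated_tools = vec![
            ToolSketch { function: FunctionSketch { name: "get_current_weather".into(), id: Some("10".into()) } },
            ToolSketch { function: FunctionSketch { name: "no_tool".into(), id: None } },
        ];

        let tool_name_to_id: HashMap<String, String> = updated_tools
            .into_iter()
            .map(|tool| {
                (
                    tool.function.name,
                    // Missing ids fall back to "0", mirroring the hunk above.
                    tool.function.id.map_or_else(|| "0".to_string(), |id| id),
                )
            })
            .collect();

        assert_eq!(tool_name_to_id["get_current_weather"], "10");
        assert_eq!(tool_name_to_id["no_tool"], "0");
    }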
@@ -1161,6 +1176,8 @@ pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
     #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
     pub arguments: serde_json::Value,
 }
@@ -1182,7 +1199,7 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }

-#[derive(Clone, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<TextMessage>,
     bos_token: Option<&'a str>,
@@ -1215,6 +1232,9 @@ pub enum MessageChunk {
 pub struct Message {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = "10")]
+    pub tool_call_id: Option<String>,
     #[serde(flatten)]
     #[schema(example = "My name is David and I")]
     pub body: MessageBody,
@@ -1294,7 +1314,7 @@ impl From<Message> for TextMessage {
                 .collect::<Vec<_>>()
                 .join(""),
             },
-            ..Default::default()
+            tool_call_id: value.tool_call_id,
         }
     }
 }
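The conversion change above is easy to miss: the struct update previously filled the remaining fields with `..Default::default()`, which silently dropped any `tool_call_id` to `None`; now the field is forwarded explicitly. A simplified sketch with hypothetical stand-in structs for `Message` and `TextMessage`:

    struct TextMessageSketch {
        role: String,
        content: String,
        tool_call_id: Option<String>,
    }

    struct MessageSketch {
        role: String,
        content: String,
        tool_call_id: Option<String>,
    }

    impl From<MessageSketch> for TextMessageSketch {
        fn from(value: MessageSketch) -> Self {
            TextMessageSketch {
                role: value.role,
                content: value.content,
                // Previously `..Default::default()` reset this to `None`.
                tool_call_id: value.tool_call_id,
            }
        }
    }

    fn main() {
        let msg = MessageSketch {
            role: "tool".into(),
            content: "22 degrees".into(),
            tool_call_id: Some("10".into()),
        };
        let text: TextMessageSketch = msg.into();
        assert_eq!(text.tool_call_id, Some("10".to_string()));
    }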
@@ -1631,6 +1651,7 @@ mod tests {
                 body: MessageBody::Content {
                     content: MessageContent::SingleText("What is Deep Learning?".to_string())
                 },
+                tool_call_id: None,
             }
         );
     }
@@ -1690,6 +1711,7 @@ mod tests {
                     MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
                 ]),
             },
+            tool_call_id: None,
         }
     );
 }
@@ -1704,7 +1726,8 @@ mod tests {
             MessageChunk::Text { text: "Whats in this image?".to_string() },
             MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
             ]),
-        }
+        },
+        tool_call_id: None
     };
     let textmsg: TextMessage = message.into();
     assert_eq!(textmsg.content, "Whats in this image?");
@@ -1765,6 +1788,7 @@ mod tests {
             id: "0".to_string(),
             r#type: "function".to_string(),
             function: FunctionDefinition {
+                id: None,
                 description: None,
                 name: "myfn".to_string(),
                 arguments: json!({
@@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions(

     tracing::debug!("Got chat_template {:?}", infer.chat_template);
     let id = chat.next_tool_call_id();
-    let (generate_request, using_tools): (GenerateRequest, bool) =
-        chat.clone().try_into_generate(&infer)?;
+    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
     let logprobs = logprobs.unwrap_or_default();

@@ -1188,7 +1187,7 @@ pub(crate) async fn chat_completions(

         let response_stream = async_stream::stream! {
             let mut response_stream = Box::pin(response_stream);
-            let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+            let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
             while let Some(result) = response_stream.next().await {
                 match result{
                     Ok(stream_token) => {
@@ -1197,12 +1196,12 @@ pub(crate) async fn chat_completions(
                             ChatEvent::NoTool => {
                                 chat.tools = None;
                                 chat.response_format = None;
-                                let (generate_request, using_tools): (GenerateRequest, bool) =
+                                let (generate_request, using_tools) =
                                     chat.clone().try_into_generate(&infer).unwrap();
-                                assert!(!using_tools);
+                                assert!(using_tools.is_none());
                                 let (_headers, response_stream2) =
                                     generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
-                                state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+                                state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
                                 response_stream = Box::pin(response_stream2);
                             }
                             ChatEvent::Events(events) => {
@@ -1237,14 +1236,13 @@ pub(crate) async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();

-            let (tool_calls, output) = if using_tools {
+            let (tool_calls, output) = if using_tools.is_some() {
                 match crate::chat::parse_output(&generation.generated_text)? {
                     ChatChoice::NoTool => {
                         chat.tools = None;
                         chat.response_format = None;
-                        let (generate_request, using_tools): (GenerateRequest, bool) =
-                            chat.clone().try_into_generate(&infer)?;
-                        assert!(!using_tools);
+                        let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
+                        assert!(using_tools.is_none());
                         let (headers_final, input_length_final, Json(generation)) = generate_internal(
                             Extension(infer),
                             compute_type,
@@ -1256,7 +1254,16 @@ pub(crate) async fn chat_completions(
                         input_length = input_length_final;
                         (None, Some(generation.generated_text))
                     }
-                    ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
+                    ChatChoice::ToolCalls(mut tool_calls) => {
+                        // assign the tool ids based on the tool names
+                        tool_calls.iter_mut().for_each(|tool_call| {
+                            tool_call.id = using_tools
+                                .as_ref()
+                                .and_then(|tools| tools.get(&tool_call.function.name))
+                                .map_or("0".to_string(), |id| id.clone());
+                        });
+                        (Some(tool_calls), None)
+                    }
                 }
             } else {
                 (None, Some(generation.generated_text))
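A standalone sketch of the id-assignment loop above: each parsed tool call looks its name up in the map returned by `try_into_generate`, falling back to "0" when the tool is unknown or no map exists. `ToolCallSketch` and `assign_ids` are hypothetical; the real type keeps the name under `function.name`.

    use std::collections::HashMap;

    struct ToolCallSketch {
        id: String,
        name: String,
    }

    fn assign_ids(tool_calls: &mut [ToolCallSketch], using_tools: Option<&HashMap<String, String>>) {
        tool_calls.iter_mut().for_each(|tool_call| {
            tool_call.id = using_tools
                .and_then(|tools| tools.get(&tool_call.name))
                .map_or("0".to_string(), |id| id.clone());
        });
    }

    fn main() {
        let map = HashMap::from([("get_current_weather".to_string(), "10".to_string())]);
        let mut calls = vec![ToolCallSketch { id: String::new(), name: "get_current_weather".into() }];
        assign_ids(&mut calls, Some(&map));
        assert_eq!(calls[0].id, "10");
    }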
|
@@ -104,8 +104,7 @@ pub(crate) async fn vertex_compatibility(
             },
         },
         VertexInstance::Chat(instance) => {
-            let (generate_request, _using_tools): (GenerateRequest, bool) =
-                instance.try_into_generate(&infer)?;
+            let (generate_request, _using_tools) = instance.try_into_generate(&infer)?;
             generate_request
         }
     };
@@ -176,6 +175,7 @@ mod tests {
                     "What's Deep Learning?".to_string()
                 )
             },
+            tool_call_id: None,
         },],
         max_tokens: Some(128),
         top_p: Some(0.95),
|