mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing types.
This commit is contained in:
parent
d8402eaf67
commit
2a87dd7274
@ -2,7 +2,7 @@
|
|||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
@ -362,16 +362,15 @@ impl ChatTemplate {
|
|||||||
if self.use_default_tool_template {
|
if self.use_default_tool_template {
|
||||||
if let Some(last_message) = messages.last_mut() {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||||
last_message.content = Some(format!(
|
last_message.content.push(MessageChunk::Text(Text {
|
||||||
"{}\n---\n{}\n{}",
|
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||||
last_message.content.as_deref().unwrap_or_default(),
|
}));
|
||||||
tool_prompt,
|
|
||||||
tools
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
self.template
|
self.template
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
messages,
|
messages,
|
||||||
@ -939,8 +938,7 @@ impl InferError {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::infer::raise_exception;
|
use crate::infer::raise_exception;
|
||||||
use crate::ChatTemplateInputs;
|
use crate::{ChatTemplateInputs, TextMessage};
|
||||||
use crate::Message;
|
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -974,33 +972,21 @@ mod tests {
|
|||||||
|
|
||||||
let chat_template_inputs = ChatTemplateInputs {
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hi!".to_string()),
|
content: "Hi!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("Hello how can I help?".to_string()),
|
content: "Hello how can I help?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("What is Deep Learning?".to_string()),
|
content: "What is Deep Learning?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("magic!".to_string()),
|
content: "magic!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -1048,40 +1034,25 @@ mod tests {
|
|||||||
|
|
||||||
let chat_template_inputs = ChatTemplateInputs {
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hi!".to_string()),
|
content: "Hi!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hi again!".to_string()),
|
content: "Hi again!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("Hello how can I help?".to_string()),
|
content: "Hello how can I help?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("What is Deep Learning?".to_string()),
|
content: "What is Deep Learning?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("magic!".to_string()),
|
content: "magic!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -1134,33 +1105,21 @@ mod tests {
|
|||||||
|
|
||||||
let chat_template_inputs = ChatTemplateInputs {
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hi!".to_string()),
|
content: "Hi!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("Hello how can I help?".to_string()),
|
content: "Hello how can I help?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("What is Deep Learning?".to_string()),
|
content: "What is Deep Learning?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("magic!".to_string()),
|
content: "magic!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -1197,33 +1156,21 @@ mod tests {
|
|||||||
|
|
||||||
let chat_template_inputs = ChatTemplateInputs {
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hi!".to_string()),
|
content: "Hi!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("Hello how can I help?".to_string()),
|
content: "Hello how can I help?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("What is Deep Learning?".to_string()),
|
content: "What is Deep Learning?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("magic!".to_string()),
|
content: "magic!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -1246,38 +1193,24 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_many_chat_templates() {
|
fn test_many_chat_templates() {
|
||||||
let example_chat = vec![
|
let example_chat = vec![
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("Hello, how are you?".to_string()),
|
content: "Hello, how are you?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: Some("I'm doing great. How can I help you today?".to_string()),
|
content: "I'm doing great. How can I help you today?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("I'd like to show off how chat templating works!".to_string()),
|
content: "I'd like to show off how chat templating works!".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
let example_chat_with_system = [Message {
|
let example_chat_with_system = [TextMessage {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: Some(
|
content: "You are a friendly chatbot who always responds in the style of a pirate"
|
||||||
"You are a friendly chatbot who always responds in the style of a pirate"
|
.to_string(),
|
||||||
.to_string(),
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
}]
|
}]
|
||||||
.iter()
|
.iter()
|
||||||
.chain(&example_chat)
|
.chain(&example_chat)
|
||||||
@ -1417,19 +1350,13 @@ mod tests {
|
|||||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
|
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
|
||||||
input: ChatTemplateInputs {
|
input: ChatTemplateInputs {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
TextMessage{
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
|
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
TextMessage{
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
|
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
|
@ -440,7 +440,7 @@ pub(crate) struct ChatCompletion {
|
|||||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct ChatCompletionComplete {
|
pub(crate) struct ChatCompletionComplete {
|
||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub message: Message,
|
pub message: OutputMessage,
|
||||||
pub logprobs: Option<ChatCompletionLogprobs>,
|
pub logprobs: Option<ChatCompletionLogprobs>,
|
||||||
pub finish_reason: String,
|
pub finish_reason: String,
|
||||||
}
|
}
|
||||||
@ -533,6 +533,17 @@ impl ChatCompletion {
|
|||||||
return_logprobs: bool,
|
return_logprobs: bool,
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let message = match (output, tool_calls) {
|
||||||
|
(Some(output), None) => OutputMessage::ChatMessage(Message {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: vec![MessageChunk::Text(Text { text: output })],
|
||||||
|
name: None,
|
||||||
|
}),
|
||||||
|
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { tool_calls }),
|
||||||
|
_ => {
|
||||||
|
todo!("Implement error for invalid tool vs chat");
|
||||||
|
}
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
object: "text_completion".into(),
|
object: "text_completion".into(),
|
||||||
@ -541,13 +552,7 @@ impl ChatCompletion {
|
|||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
choices: vec![ChatCompletionComplete {
|
choices: vec![ChatCompletionComplete {
|
||||||
index: 0,
|
index: 0,
|
||||||
message: Message {
|
message,
|
||||||
role: "assistant".into(),
|
|
||||||
content: output,
|
|
||||||
name: None,
|
|
||||||
tool_calls,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
|
||||||
logprobs: return_logprobs
|
logprobs: return_logprobs
|
||||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||||
finish_reason: details.finish_reason.to_string(),
|
finish_reason: details.finish_reason.to_string(),
|
||||||
@ -852,7 +857,7 @@ where
|
|||||||
state.end()
|
state.end()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
||||||
pub(crate) struct FunctionDefinition {
|
pub(crate) struct FunctionDefinition {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
@ -872,7 +877,7 @@ pub(crate) struct Tool {
|
|||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, Default)]
|
#[derive(Clone, Serialize, Deserialize, Default)]
|
||||||
pub(crate) struct ChatTemplateInputs<'a> {
|
pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
messages: Vec<Message>,
|
messages: Vec<TextMessage>,
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
add_generation_prompt: bool,
|
add_generation_prompt: bool,
|
||||||
@ -880,91 +885,112 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
tools_prompt: Option<&'a str>,
|
tools_prompt: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
|
||||||
pub(crate) struct ToolCall {
|
pub(crate) struct ToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
pub function: FunctionDefinition,
|
pub function: FunctionDefinition,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
pub(crate) struct Text {
|
struct Url {
|
||||||
#[serde(default)]
|
url: String,
|
||||||
pub text: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
pub(crate) struct ImageUrl {
|
struct ImageUrl {
|
||||||
#[serde(default)]
|
image_url: Url,
|
||||||
pub url: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
pub(crate) struct Content {
|
struct Text {
|
||||||
pub r#type: String,
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
enum MessageChunk {
|
||||||
|
Text(Text),
|
||||||
|
ImageUrl(ImageUrl),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
|
pub struct Message {
|
||||||
|
#[schema(example = "user")]
|
||||||
|
role: String,
|
||||||
|
#[schema(example = "My name is David and I")]
|
||||||
|
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||||
|
content: Vec<MessageChunk>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub text: Option<String>,
|
#[schema(example = "\"David\"")]
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
name: Option<String>,
|
||||||
pub image_url: Option<ImageUrl>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mod message_content_serde {
|
mod message_content_serde {
|
||||||
use super::*;
|
use super::*;
|
||||||
use serde::de;
|
use serde::{Deserialize, Deserializer};
|
||||||
use serde::Deserializer;
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let value = Value::deserialize(deserializer)?;
|
#[derive(Deserialize)]
|
||||||
match value {
|
#[serde(untagged)]
|
||||||
Value::String(s) => Ok(Some(s)),
|
enum Message {
|
||||||
Value::Array(arr) => {
|
Text(String),
|
||||||
let results: Result<Vec<String>, _> = arr
|
Chunks(Vec<MessageChunk>),
|
||||||
.into_iter()
|
}
|
||||||
.map(|v| {
|
let message: Message = Deserialize::deserialize(deserializer)?;
|
||||||
let content: Content =
|
let chunks = match message {
|
||||||
serde_json::from_value(v).map_err(de::Error::custom)?;
|
Message::Text(text) => {
|
||||||
match content.r#type.as_str() {
|
vec![MessageChunk::Text(Text { text })]
|
||||||
"text" => Ok(content.text.unwrap_or_default()),
|
|
||||||
"image_url" => {
|
|
||||||
if let Some(url) = content.image_url {
|
|
||||||
Ok(format!("", url.url))
|
|
||||||
} else {
|
|
||||||
Ok(String::new())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err(de::Error::custom("invalid content type")),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
results.map(|strings| Some(strings.join("")))
|
|
||||||
}
|
}
|
||||||
Value::Null => Ok(None),
|
Message::Chunks(s) => s,
|
||||||
_ => Err(de::Error::custom("invalid token format")),
|
};
|
||||||
|
Ok(chunks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
|
pub struct TextMessage {
|
||||||
|
#[schema(example = "user")]
|
||||||
|
pub role: String,
|
||||||
|
#[schema(example = "My name is David and I")]
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Message> for TextMessage {
|
||||||
|
fn from(value: Message) -> Self {
|
||||||
|
TextMessage {
|
||||||
|
role: value.role,
|
||||||
|
content: value
|
||||||
|
.content
|
||||||
|
.into_iter()
|
||||||
|
.map(|c| match c {
|
||||||
|
MessageChunk::Text(Text { text }) => text,
|
||||||
|
MessageChunk::ImageUrl(image) => {
|
||||||
|
let url = image.image_url.url;
|
||||||
|
format!("")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(""),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
pub(crate) struct Message {
|
pub struct ToolCallMessage {
|
||||||
#[schema(example = "user")]
|
tool_calls: Vec<ToolCall>,
|
||||||
pub role: String,
|
tool_call_id: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
}
|
||||||
#[schema(example = "My name is David and I")]
|
|
||||||
#[serde(default, deserialize_with = "message_content_serde::deserialize")]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
pub content: Option<String>,
|
#[serde(untagged)]
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
pub(crate) enum OutputMessage {
|
||||||
#[schema(example = "\"David\"")]
|
ChatMessage(Message),
|
||||||
pub name: Option<String>,
|
ToolCall(ToolCallMessage),
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
#[schema(example = "\"get_weather\"")]
|
|
||||||
pub tool_call_id: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
@ -1127,7 +1153,7 @@ pub(crate) struct ErrorResponse {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
||||||
@ -1195,4 +1221,65 @@ mod tests {
|
|||||||
);
|
);
|
||||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_simple_string() {
|
||||||
|
let json = json!(
|
||||||
|
|
||||||
|
{
|
||||||
|
"model": "",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user",
|
||||||
|
"content": "What is Deep Learning?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
});
|
||||||
|
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
request.messages[0],
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![MessageChunk::Text(Text {
|
||||||
|
text: "What is Deep Learning?".to_string()
|
||||||
|
}),],
|
||||||
|
name: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_request() {
|
||||||
|
let json = json!(
|
||||||
|
|
||||||
|
{
|
||||||
|
"model": "",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Whats in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
});
|
||||||
|
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
request.messages[0],
|
||||||
|
Message{
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![
|
||||||
|
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
||||||
|
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
||||||
|
],
|
||||||
|
name: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user