mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: leverage serde for conditional deser
This commit is contained in:
parent
4ba5e74efc
commit
d759a7f492
@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
|
||||
};
|
||||
use crate::{
|
||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
@ -270,7 +272,11 @@ struct ChatTemplate {
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
fn new(
|
||||
template: String,
|
||||
bos_token: Option<TokenizerConfigToken>,
|
||||
eos_token: Option<TokenizerConfigToken>,
|
||||
) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
@ -287,8 +293,20 @@ impl ChatTemplate {
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
bos_token: match bos_token {
|
||||
Some(token) => match token {
|
||||
TokenizerConfigToken::String(token) => Some(token),
|
||||
TokenizerConfigToken::Object { content } => Some(content),
|
||||
},
|
||||
None => None,
|
||||
},
|
||||
eos_token: match eos_token {
|
||||
Some(token) => match token {
|
||||
TokenizerConfigToken::String(token) => Some(token),
|
||||
TokenizerConfigToken::Object { content } => Some(content),
|
||||
},
|
||||
None => None,
|
||||
},
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
@ -301,9 +319,9 @@ impl ChatTemplate {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text(Text {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -340,6 +358,14 @@ impl ToolGrammar {
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
let tool = req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == function.name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
|
||||
.clone();
|
||||
vec![tool]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
|
@ -53,6 +53,8 @@ pub enum ChatTemplateVersions {
|
||||
Multiple(Vec<ChatTemplate>),
|
||||
}
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
@ -67,9 +69,26 @@ pub struct HubTokenizerConfig {
|
||||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.and_then(|content| serde_json::from_str(&content).ok())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum TokenizerConfigToken {
|
||||
String(String),
|
||||
Object { content: String },
|
||||
}
|
||||
|
||||
impl From<TokenizerConfigToken> for String {
|
||||
fn from(token: TokenizerConfigToken) -> Self {
|
||||
match token {
|
||||
TokenizerConfigToken::String(s) => s,
|
||||
TokenizerConfigToken::Object { content } => content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,9 +119,10 @@ pub struct HubProcessorConfig {
|
||||
}
|
||||
|
||||
impl HubProcessorConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.and_then(|content| serde_json::from_str(&content).ok())
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,35 +141,6 @@ pub(crate) enum GrammarType {
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
mod token_serde {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(s)),
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
|
||||
Ok(Some(content.to_string()))
|
||||
} else {
|
||||
Err(de::Error::custom(
|
||||
"content key not found in structured token",
|
||||
))
|
||||
}
|
||||
}
|
||||
Value::Null => Ok(None),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
@ -359,30 +350,33 @@ fn default_parameters() -> GenerateParameters {
|
||||
}
|
||||
}
|
||||
|
||||
mod prompt_serde {
|
||||
use serde::{self, Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
#[serde(try_from = "PromptDeserializer")]
|
||||
pub struct Prompt(pub Vec<String>);
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum PromptDeserializer {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
impl TryFrom<PromptDeserializer> for Prompt {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
Value::String(s) => Ok(vec![s]),
|
||||
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
|
||||
"Empty array detected. Do not use an empty array for the prompt.",
|
||||
)),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.map(|v| match v {
|
||||
Value::String(s) => Ok(s.to_owned()),
|
||||
_ => Err(serde::de::Error::custom("Expected a string")),
|
||||
})
|
||||
.collect(),
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"Expected a string or an array of strings",
|
||||
)),
|
||||
PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),
|
||||
PromptDeserializer::Multiple(v) => {
|
||||
if v.is_empty() {
|
||||
Err(
|
||||
"Empty array detected. Do not use an empty array for the prompt."
|
||||
.to_string(),
|
||||
)
|
||||
} else {
|
||||
Ok(Prompt(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -396,8 +390,7 @@ pub struct CompletionRequest {
|
||||
|
||||
/// The prompt to generate completions for.
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
#[serde(deserialize_with = "prompt_serde::deserialize")]
|
||||
pub prompt: Vec<String>,
|
||||
pub prompt: Prompt,
|
||||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
@ -824,7 +817,6 @@ pub(crate) struct ChatRequest {
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
@ -840,44 +832,41 @@ fn default_tool_prompt() -> Option<String> {
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
enum ToolType {
|
||||
FunctionName(String),
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolType {
|
||||
OneOf,
|
||||
FunctionName(String),
|
||||
Function { function: FunctionName },
|
||||
}
|
||||
|
||||
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
||||
mod deserialize_tool_choice {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct FunctionName {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(from = "ToolTypeDeserializer")]
|
||||
pub struct ToolChoice(pub Option<ToolType>);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ToolTypeDeserializer {
|
||||
None(Option<String>),
|
||||
Some(ToolType),
|
||||
}
|
||||
|
||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
fn from(value: ToolTypeDeserializer) -> Self {
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(None),
|
||||
"auto" => Ok(Some(ToolType::OneOf)),
|
||||
_ => Ok(Some(ToolType::FunctionName(s))),
|
||||
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
|
||||
Some("none") => ToolChoice(None),
|
||||
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
|
||||
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
|
||||
None => ToolChoice(Some(ToolType::OneOf)),
|
||||
},
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map
|
||||
.get("function")
|
||||
.and_then(|v| v.get("name"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
Ok(Some(ToolType::FunctionName(content.to_string())))
|
||||
} else {
|
||||
Err(de::Error::custom("function key not found in tool choice"))
|
||||
}
|
||||
}
|
||||
Value::Null => Ok(Some(ToolType::OneOf)),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -953,26 +942,16 @@ pub(crate) struct ToolCall {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct Url {
|
||||
pub struct Url {
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct ImageUrl {
|
||||
image_url: Url,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
struct Text {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum MessageChunk {
|
||||
Text(Text),
|
||||
ImageUrl(ImageUrl),
|
||||
pub enum MessageChunk {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: Url },
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
@ -980,35 +959,31 @@ pub struct Message {
|
||||
#[schema(example = "user")]
|
||||
role: String,
|
||||
#[schema(example = "My name is David and I")]
|
||||
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||
content: Vec<MessageChunk>,
|
||||
pub content: MessageContent,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
mod message_content_serde {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
SingleText(String),
|
||||
MultipleChunks(Vec<MessageChunk>),
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Message {
|
||||
Text(String),
|
||||
Chunks(Vec<MessageChunk>),
|
||||
// Pushing a chunk to a single text message will convert it to a multiple chunks message
|
||||
impl MessageContent {
|
||||
pub fn push(&mut self, chunk: MessageChunk) {
|
||||
match self {
|
||||
MessageContent::SingleText(text) => {
|
||||
*self =
|
||||
MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
|
||||
}
|
||||
MessageContent::MultipleChunks(chunks) => {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
let message: Message = Deserialize::deserialize(deserializer)?;
|
||||
let chunks = match message {
|
||||
Message::Text(text) => {
|
||||
vec![MessageChunk::Text(Text { text })]
|
||||
}
|
||||
Message::Chunks(s) => s,
|
||||
};
|
||||
Ok(chunks)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1024,18 +999,17 @@ impl From<Message> for TextMessage {
|
||||
fn from(value: Message) -> Self {
|
||||
TextMessage {
|
||||
role: value.role,
|
||||
content: value
|
||||
.content
|
||||
content: match value.content {
|
||||
MessageContent::SingleText(text) => text,
|
||||
MessageContent::MultipleChunks(chunks) => chunks
|
||||
.into_iter()
|
||||
.map(|c| match c {
|
||||
MessageChunk::Text(Text { text }) => text,
|
||||
MessageChunk::ImageUrl(image) => {
|
||||
let url = image.image_url.url;
|
||||
format!("")
|
||||
}
|
||||
.map(|chunk| match chunk {
|
||||
MessageChunk::Text { text } => text,
|
||||
MessageChunk::ImageUrl { image_url } => format!("", image_url.url),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1243,9 +1217,16 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
Some(TokenizerConfigToken::String(
|
||||
"<|begin▁of▁sentence|>".to_string()
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
config.eos_token,
|
||||
Some(TokenizerConfigToken::String(
|
||||
"<|end▁of▁sentence|>".to_string()
|
||||
))
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
|
||||
// in this case we expect the tokens to be encoded as structured tokens
|
||||
// we want the content of the structured token
|
||||
@ -1278,9 +1259,16 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
Some(TokenizerConfigToken::Object {
|
||||
content: "<|begin▁of▁sentence|>".to_string()
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
config.eos_token,
|
||||
Some(TokenizerConfigToken::Object {
|
||||
content: "<|end▁of▁sentence|>".to_string()
|
||||
})
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1298,9 +1286,7 @@ mod tests {
|
||||
request.messages[0],
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![MessageChunk::Text(Text {
|
||||
text: "What is Deep Learning?".to_string()
|
||||
}),],
|
||||
content: MessageContent::SingleText("What is Deep Learning?".to_string()),
|
||||
name: None
|
||||
}
|
||||
);
|
||||
@ -1324,10 +1310,10 @@ mod tests {
|
||||
request.messages[0],
|
||||
Message{
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
||||
],
|
||||
content: MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
|
||||
]),
|
||||
name: None
|
||||
}
|
||||
);
|
||||
@ -1337,10 +1323,10 @@ mod tests {
|
||||
fn text_message_convert() {
|
||||
let message = Message{
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
||||
],
|
||||
content: MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
|
||||
]),
|
||||
name: None
|
||||
};
|
||||
let textmsg: TextMessage = message.into();
|
||||
|
@ -636,7 +636,7 @@ async fn completions(
|
||||
));
|
||||
}
|
||||
|
||||
if req.prompt.len() > info.max_client_batch_size {
|
||||
if req.prompt.0.len() > info.max_client_batch_size {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
@ -652,6 +652,7 @@ async fn completions(
|
||||
|
||||
let generate_requests: Vec<GenerateRequest> = req
|
||||
.prompt
|
||||
.0
|
||||
.iter()
|
||||
.map(|prompt| GenerateRequest {
|
||||
inputs: prompt.to_string(),
|
||||
|
Loading…
Reference in New Issue
Block a user