mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: prefer custom deserializer for complex message content
This commit is contained in:
parent
a480273047
commit
91a705e1a9
1436
router/src/infer.rs
1436
router/src/infer.rs
File diff suppressed because it is too large
Load Diff
@ -525,7 +525,7 @@ impl ChatCompletion {
|
|||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
model: String,
|
model: String,
|
||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
_output: Option<String>,
|
output: Option<String>,
|
||||||
created: u64,
|
created: u64,
|
||||||
details: Details,
|
details: Details,
|
||||||
return_logprobs: bool,
|
return_logprobs: bool,
|
||||||
@ -541,7 +541,7 @@ impl ChatCompletion {
|
|||||||
index: 0,
|
index: 0,
|
||||||
message: Message {
|
message: Message {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: None,
|
content: output,
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
},
|
},
|
||||||
@ -878,7 +878,7 @@ pub(crate) struct SerializedMessage {
|
|||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, Default)]
|
#[derive(Clone, Serialize, Deserialize, Default)]
|
||||||
pub(crate) struct ChatTemplateInputs<'a> {
|
pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
messages: Vec<SerializedMessage>,
|
messages: Vec<Message>,
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
add_generation_prompt: bool,
|
add_generation_prompt: bool,
|
||||||
@ -914,13 +914,55 @@ pub(crate) struct Content {
|
|||||||
pub image_url: Option<ImageUrl>,
|
pub image_url: Option<ImageUrl>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod message_content_serde {
|
||||||
|
use super::*;
|
||||||
|
use serde::de;
|
||||||
|
use serde::Deserializer;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
match value {
|
||||||
|
Value::String(s) => Ok(Some(s)),
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let results: Result<Vec<String>, _> = arr
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| {
|
||||||
|
let content: Content =
|
||||||
|
serde_json::from_value(v).map_err(de::Error::custom)?;
|
||||||
|
match content.r#type.as_str() {
|
||||||
|
"text" => Ok(content.text.unwrap_or_default()),
|
||||||
|
"image_url" => {
|
||||||
|
if let Some(url) = content.image_url {
|
||||||
|
Ok(format!("\n", url.url))
|
||||||
|
} else {
|
||||||
|
Ok(String::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(de::Error::custom("invalid content type")),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
results.map(|strings| Some(strings.join("")))
|
||||||
|
}
|
||||||
|
Value::Null => Ok(None),
|
||||||
|
_ => Err(de::Error::custom("invalid token format")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
|
||||||
pub(crate) struct Message {
|
pub(crate) struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
pub role: String,
|
pub role: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: Option<Vec<Content>>,
|
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||||
|
pub content: Option<String>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
Loading…
Reference in New Issue
Block a user