feat: prefer custom deserializer for complex message content

This commit is contained in:
drbh 2024-04-29 16:07:53 +00:00
parent a480273047
commit 91a705e1a9
2 changed files with 749 additions and 737 deletions

File diff suppressed because it is too large Load Diff

View File

@ -525,7 +525,7 @@ impl ChatCompletion {
pub(crate) fn new( pub(crate) fn new(
model: String, model: String,
system_fingerprint: String, system_fingerprint: String,
_output: Option<String>, output: Option<String>,
created: u64, created: u64,
details: Details, details: Details,
return_logprobs: bool, return_logprobs: bool,
@ -541,7 +541,7 @@ impl ChatCompletion {
index: 0, index: 0,
message: Message { message: Message {
role: "assistant".into(), role: "assistant".into(),
content: None, content: output,
name: None, name: None,
tool_calls, tool_calls,
}, },
@ -878,7 +878,7 @@ pub(crate) struct SerializedMessage {
#[derive(Clone, Serialize, Deserialize, Default)] #[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> { pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<SerializedMessage>, messages: Vec<Message>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
@ -914,13 +914,55 @@ pub(crate) struct Content {
pub image_url: Option<ImageUrl>, pub image_url: Option<ImageUrl>,
} }
mod message_content_serde {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(Some(s)),
Value::Array(arr) => {
let results: Result<Vec<String>, _> = arr
.into_iter()
.map(|v| {
let content: Content =
serde_json::from_value(v).map_err(de::Error::custom)?;
match content.r#type.as_str() {
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("\n![]({})", url.url))
} else {
Ok(String::new())
}
}
_ => Err(de::Error::custom("invalid content type")),
}
})
.collect();
results.map(|strings| Some(strings.join("")))
}
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
}
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
pub(crate) struct Message { pub(crate) struct Message {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, pub role: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")] #[schema(example = "My name is David and I")]
pub content: Option<Vec<Content>>, #[serde(deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]
pub name: Option<String>, pub name: Option<String>,