make content field optional in Message for role=assistant

This commit is contained in:
Andrew Reed 2025-02-18 22:31:56 +00:00
parent 3b096626e8
commit 3eef08b7a1
2 changed files with 34 additions and 6 deletions

View File

@ -1,5 +1,5 @@
use crate::infer::InferError;
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use crate::{ChatTemplateInputs, Message, MessageChunk, MessageContent, TextMessage, TokenizerConfigToken, Tool};
use chrono::Local;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
@ -74,14 +74,38 @@ impl ChatTemplate {
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text });
if let Some(ref mut content) = last_message.content {
content.push(MessageChunk::Text { text });
} else {
last_message.content = Some(MessageContent::SingleText(text));
}
}
Some(tools)
}
None => None,
};
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
let messages: Vec<TextMessage> = messages
.into_iter()
.map(|m| {
if m.role == "assistant" && m.tool_calls.is_some() && m.content.is_none() {
// For assistant messages with tool calls but no content,
// just use the standard conversion which will handle None content
Ok(m.into())
} else if m.content.is_none() {
// For messages requiring content but having none, return error
return Err(InferError::TemplateError(
minijinja::Error::new(
minijinja::ErrorKind::SyntaxError,
"Content is required for this message type",
)
));
} else {
Ok(m.into())
}
})
.collect::<Result<_, _>>()?;
let final_message = messages.last().cloned();
let mut rendered_template = self
.template

View File

@ -1189,10 +1189,13 @@ pub struct Message {
#[schema(example = "user")]
role: String,
#[schema(example = "My name is David and I")]
pub content: MessageContent,
#[serde(default, skip_serializing_if = "Option::is_none")]
content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
@ -1232,8 +1235,8 @@ impl From<Message> for TextMessage {
TextMessage {
role: value.role,
content: match value.content {
MessageContent::SingleText(text) => text,
MessageContent::MultipleChunks(chunks) => chunks
Some(MessageContent::SingleText(text)) => text,
Some(MessageContent::MultipleChunks(chunks)) => chunks
.into_iter()
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
@ -1241,6 +1244,7 @@ impl From<Message> for TextMessage {
})
.collect::<Vec<_>>()
.join(""),
None => String::new(),
},
}
}