mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
make content field optional in Message for role=assistant
This commit is contained in:
parent
3b096626e8
commit
3eef08b7a1
@ -1,5 +1,5 @@
|
||||
use crate::infer::InferError;
|
||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||
use crate::{ChatTemplateInputs, Message, MessageChunk, MessageContent, TextMessage, TokenizerConfigToken, Tool};
|
||||
use chrono::Local;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
@ -74,14 +74,38 @@ impl ChatTemplate {
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
last_message.content.push(MessageChunk::Text { text });
|
||||
if let Some(ref mut content) = last_message.content {
|
||||
content.push(MessageChunk::Text { text });
|
||||
} else {
|
||||
last_message.content = Some(MessageContent::SingleText(text));
|
||||
}
|
||||
}
|
||||
Some(tools)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
let messages: Vec<TextMessage> = messages
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
if m.role == "assistant" && m.tool_calls.is_some() && m.content.is_none() {
|
||||
// For assistant messages with tool calls but no content,
|
||||
// just use the standard conversion which will handle None content
|
||||
Ok(m.into())
|
||||
} else if m.content.is_none() {
|
||||
// For messages requiring content but having none, return error
|
||||
return Err(InferError::TemplateError(
|
||||
minijinja::Error::new(
|
||||
minijinja::ErrorKind::SyntaxError,
|
||||
"Content is required for this message type",
|
||||
)
|
||||
));
|
||||
} else {
|
||||
Ok(m.into())
|
||||
}
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let final_message = messages.last().cloned();
|
||||
let mut rendered_template = self
|
||||
.template
|
||||
|
@ -1189,10 +1189,13 @@ pub struct Message {
|
||||
#[schema(example = "user")]
|
||||
role: String,
|
||||
#[schema(example = "My name is David and I")]
|
||||
pub content: MessageContent,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
content: Option<MessageContent>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
@ -1232,8 +1235,8 @@ impl From<Message> for TextMessage {
|
||||
TextMessage {
|
||||
role: value.role,
|
||||
content: match value.content {
|
||||
MessageContent::SingleText(text) => text,
|
||||
MessageContent::MultipleChunks(chunks) => chunks
|
||||
Some(MessageContent::SingleText(text)) => text,
|
||||
Some(MessageContent::MultipleChunks(chunks)) => chunks
|
||||
.into_iter()
|
||||
.map(|chunk| match chunk {
|
||||
MessageChunk::Text { text } => text,
|
||||
@ -1241,6 +1244,7 @@ impl From<Message> for TextMessage {
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
None => String::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user