mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
make content field optional in Message for role=assistant
This commit is contained in:
parent
3b096626e8
commit
3eef08b7a1
@ -1,5 +1,5 @@
|
|||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
use crate::{ChatTemplateInputs, Message, MessageChunk, MessageContent, TextMessage, TokenizerConfigToken, Tool};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
@ -74,14 +74,38 @@ impl ChatTemplate {
|
|||||||
format!("\n---\n{}", tool_prompt)
|
format!("\n---\n{}", tool_prompt)
|
||||||
};
|
};
|
||||||
if let Some(last_message) = messages.last_mut() {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
last_message.content.push(MessageChunk::Text { text });
|
if let Some(ref mut content) = last_message.content {
|
||||||
|
content.push(MessageChunk::Text { text });
|
||||||
|
} else {
|
||||||
|
last_message.content = Some(MessageContent::SingleText(text));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(tools)
|
Some(tools)
|
||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
let messages: Vec<TextMessage> = messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|m| {
|
||||||
|
if m.role == "assistant" && m.tool_calls.is_some() && m.content.is_none() {
|
||||||
|
// For assistant messages with tool calls but no content,
|
||||||
|
// just use the standard conversion which will handle None content
|
||||||
|
Ok(m.into())
|
||||||
|
} else if m.content.is_none() {
|
||||||
|
// For messages requiring content but having none, return error
|
||||||
|
return Err(InferError::TemplateError(
|
||||||
|
minijinja::Error::new(
|
||||||
|
minijinja::ErrorKind::SyntaxError,
|
||||||
|
"Content is required for this message type",
|
||||||
|
)
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
Ok(m.into())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
let final_message = messages.last().cloned();
|
let final_message = messages.last().cloned();
|
||||||
let mut rendered_template = self
|
let mut rendered_template = self
|
||||||
.template
|
.template
|
||||||
|
@ -1189,10 +1189,13 @@ pub struct Message {
|
|||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
role: String,
|
role: String,
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: MessageContent,
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<MessageContent>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
@ -1232,8 +1235,8 @@ impl From<Message> for TextMessage {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: value.role,
|
role: value.role,
|
||||||
content: match value.content {
|
content: match value.content {
|
||||||
MessageContent::SingleText(text) => text,
|
Some(MessageContent::SingleText(text)) => text,
|
||||||
MessageContent::MultipleChunks(chunks) => chunks
|
Some(MessageContent::MultipleChunks(chunks)) => chunks
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|chunk| match chunk {
|
.map(|chunk| match chunk {
|
||||||
MessageChunk::Text { text } => text,
|
MessageChunk::Text { text } => text,
|
||||||
@ -1241,6 +1244,7 @@ impl From<Message> for TextMessage {
|
|||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(""),
|
.join(""),
|
||||||
|
None => String::new(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user