commit ede406a32c
Author: Wang, Yi
Date: 2025-06-18 15:21:35 +02:00 (committed by GitHub)
2 changed files with 35 additions and 14 deletions

File 1 of 2:

@@ -6,7 +6,7 @@ use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
-    Message, PrefillToken, Token,
+    Message, PrefillToken, TextMessage, Token,
 };
 use async_stream::stream;
 use async_trait::async_trait;
@@ -68,17 +68,29 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        let chat_template = tokenizer_config
-            .chat_template
-            .or(processor_config.chat_template)
-            .and_then(|t| match t {
-                ChatTemplateVersions::Single(template) => Some(template),
-                ChatTemplateVersions::Multiple(templates) => templates
-                    .into_iter()
-                    .find(|t| t.name == "default")
-                    .map(|t| t.template),
-            })
-            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+        let chat_template = if matches!(
+            processor_config.processor_class.as_deref(),
+            Some("Llama4Processor")
+                | Some("LlavaNextProcessor")
+                | Some("Idefics2Processor")
+                | Some("Idefics3Processor")
+        ) {
+            None // Do not use chat_template
+        } else {
+            tokenizer_config
+                .chat_template
+                .or(processor_config.chat_template)
+                .and_then(|t| match t {
+                    ChatTemplateVersions::Single(template) => Some(template),
+                    ChatTemplateVersions::Multiple(templates) => templates
+                        .into_iter()
+                        .find(|t| t.name == "default")
+                        .map(|t| t.template),
+                })
+                .map(|t| {
+                    ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+                })
+        };
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
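In plain terms, the hunk above makes the constructor skip building a ChatTemplate whenever the processor class is one of the four listed multimodal processors. Below is a minimal standalone sketch of that gate; uses_chat_template is an illustrative helper name, not part of the crate:

// Sketch only: mirrors the `matches!` gate in the hunk above.
fn uses_chat_template(processor_class: Option<&str>) -> bool {
    !matches!(
        processor_class,
        Some("Llama4Processor")
            | Some("LlavaNextProcessor")
            | Some("Idefics2Processor")
            | Some("Idefics3Processor")
    )
}

fn main() {
    // Multimodal processors skip the template entirely.
    assert!(!uses_chat_template(Some("Idefics3Processor")));
    // Any other (or missing) processor class keeps the usual lookup.
    assert!(uses_chat_template(Some("SomeOtherProcessor")));
    assert!(uses_chat_template(None));
}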
@@ -229,6 +241,15 @@ impl Infer {
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
+        if self.chat_template.is_none() {
+            let textmessages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+            let message_str = textmessages
+                .iter()
+                .map(|msg| msg.content.clone()) // Extract content from each `TextMessage`
+                .collect::<Vec<String>>() // Collect all content into a vector
+                .join("\n"); // Join all content into a single string separated by newlines
+            return Ok(message_str);
+        }
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
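The early return added above gives the template-application path a fallback when no chat template exists: each Message is converted into a TextMessage and the contents are joined with newlines, dropping the roles. A self-contained sketch of that joining step, with TextMessage stubbed locally (the real type is the router's own):

// Sketch of the template-free fallback; `TextMessage` is stubbed here.
struct TextMessage {
    #[allow(dead_code)] // unused in this sketch; the fallback keeps only content
    role: String,
    content: String,
}

fn join_messages(messages: Vec<TextMessage>) -> String {
    messages
        .iter()
        .map(|msg| msg.content.clone()) // keep the content, drop the role
        .collect::<Vec<String>>()
        .join("\n") // one message per line
}

fn main() {
    let msgs = vec![
        TextMessage { role: "user".into(), content: "Describe this image.".into() },
        TextMessage { role: "assistant".into(), content: "A cat on a mat.".into() },
    ];
    assert_eq!(join_messages(msgs), "Describe this image.\nA cat on a mat.");
}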

File 2 of 2:

@@ -210,7 +210,7 @@ pub struct Llama4Processor {
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,
-    pub image_seq_len: usize,
+    pub image_seq_len: Option<usize>,
     pub processor_class: Option<String>,
 }
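Widening image_seq_len to Option<usize> matters for deserialization: serde treats Option fields as optional by default, so a processor_config.json that omits the key now yields None instead of an error (the derive shown above carries no #[serde(default)]). A hedged sketch with simplified types, assuming the serde and serde_json crates:

use serde::Deserialize;

// Simplified stand-in for the router's HubProcessorConfig.
#[derive(Debug, Deserialize, Default)]
struct HubProcessorConfig {
    chat_template: Option<String>, // simplified from ChatTemplateVersions
    image_seq_len: Option<usize>,  // the field this commit makes optional
    processor_class: Option<String>,
}

fn main() {
    // A config that omits image_seq_len now deserializes cleanly.
    let json = r#"{ "processor_class": "Llama4Processor" }"#;
    let cfg: HubProcessorConfig = serde_json::from_str(json).unwrap();
    assert_eq!(cfg.image_seq_len, None);
    assert_eq!(cfg.processor_class.as_deref(), Some("Llama4Processor"));
    assert!(cfg.chat_template.is_none());
}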
@@ -1008,7 +1008,7 @@ impl ChatRequest {
         Ok((
             GenerateRequest {
                 inputs: inputs.to_string(),
-                add_special_tokens: false,
+                add_special_tokens: infer.chat_template.is_none(),
                 parameters: GenerateParameters {
                     best_of: None,
                     temperature,
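Finally, add_special_tokens is now tied to whether a chat template exists. A rendered template typically emits special tokens such as BOS itself, so the tokenizer should only add them when the raw newline-join fallback produced the input. A one-line sketch of that intent (illustrative helper, not the crate's API):

// Illustrative only: the real call site sets the flag inline on GenerateRequest.
fn add_special_tokens(has_chat_template: bool) -> bool {
    !has_chat_template
}

fn main() {
    assert!(!add_special_tokens(true)); // template already emits special tokens
    assert!(add_special_tokens(false)); // tokenizer adds them in the fallback
}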