This commit is contained in:
Wang, Yi 2025-06-18 15:21:35 +02:00 committed by GitHub
commit ede406a32c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 14 deletions

View File

@ -6,7 +6,7 @@ use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::Tool;
use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token,
Message, PrefillToken, TextMessage, Token,
};
use async_stream::stream;
use async_trait::async_trait;
@ -68,7 +68,16 @@ impl Infer {
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
) -> Self {
let chat_template = tokenizer_config
let chat_template = if matches!(
processor_config.processor_class.as_deref(),
Some("Llama4Processor")
| Some("LlavaNextProcessor")
| Some("Idefics2Processor")
| Some("Idefics3Processor")
) {
None // Do not use chat_template
} else {
tokenizer_config
.chat_template
.or(processor_config.chat_template)
.and_then(|t| match t {
@ -78,7 +87,10 @@ impl Infer {
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
.map(|t| {
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
})
};
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@ -229,6 +241,15 @@ impl Infer {
messages: Vec<Message>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
if self.chat_template.is_none() {
let textmessages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
let message_str = textmessages
.iter()
.map(|msg| msg.content.clone()) // Extract content from each `TextMessage`
.collect::<Vec<String>>() // Collect all content into a vector
.join("\n"); // Join all content into a single string separated by newlines
return Ok(message_str);
}
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?

View File

@ -210,7 +210,7 @@ pub struct Llama4Processor {
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,
pub image_seq_len: usize,
pub image_seq_len: Option<usize>,
pub processor_class: Option<String>,
}
@ -1008,7 +1008,7 @@ impl ChatRequest {
Ok((
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
add_special_tokens: infer.chat_template.is_none(),
parameters: GenerateParameters {
best_of: None,
temperature,