mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00
Merge c43954d44c
into bd1bdebb47
This commit is contained in:
commit
ede406a32c
@ -6,7 +6,7 @@ use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::Tool;
|
||||
use crate::{
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
Message, PrefillToken, TextMessage, Token,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
@ -68,7 +68,16 @@ impl Infer {
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
processor_config: HubProcessorConfig,
|
||||
) -> Self {
|
||||
let chat_template = tokenizer_config
|
||||
let chat_template = if matches!(
|
||||
processor_config.processor_class.as_deref(),
|
||||
Some("Llama4Processor")
|
||||
| Some("LlavaNextProcessor")
|
||||
| Some("Idefics2Processor")
|
||||
| Some("Idefics3Processor")
|
||||
) {
|
||||
None // Do not use chat_template
|
||||
} else {
|
||||
tokenizer_config
|
||||
.chat_template
|
||||
.or(processor_config.chat_template)
|
||||
.and_then(|t| match t {
|
||||
@ -78,7 +87,10 @@ impl Infer {
|
||||
.find(|t| t.name == "default")
|
||||
.map(|t| t.template),
|
||||
})
|
||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||
.map(|t| {
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
})
|
||||
};
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
@ -229,6 +241,15 @@ impl Infer {
|
||||
messages: Vec<Message>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.chat_template.is_none() {
|
||||
let textmessages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
let message_str = textmessages
|
||||
.iter()
|
||||
.map(|msg| msg.content.clone()) // Extract content from each `TextMessage`
|
||||
.collect::<Vec<String>>() // Collect all content into a vector
|
||||
.join("\n"); // Join all content into a single string separated by newlines
|
||||
return Ok(message_str);
|
||||
}
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
|
@ -210,7 +210,7 @@ pub struct Llama4Processor {
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubProcessorConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub image_seq_len: usize,
|
||||
pub image_seq_len: Option<usize>,
|
||||
pub processor_class: Option<String>,
|
||||
}
|
||||
|
||||
@ -1008,7 +1008,7 @@ impl ChatRequest {
|
||||
Ok((
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
add_special_tokens: infer.chat_template.is_none(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
|
Loading…
Reference in New Issue
Block a user