mirror of https://github.com/huggingface/text-generation-inference.git

commit ede406a32c: Merge c43954d44c into bd1bdebb47
@@ -6,7 +6,7 @@ use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
-    Message, PrefillToken, Token,
+    Message, PrefillToken, TextMessage, Token,
 };
 use async_stream::stream;
 use async_trait::async_trait;
@@ -68,17 +68,29 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        let chat_template = tokenizer_config
-            .chat_template
-            .or(processor_config.chat_template)
-            .and_then(|t| match t {
-                ChatTemplateVersions::Single(template) => Some(template),
-                ChatTemplateVersions::Multiple(templates) => templates
-                    .into_iter()
-                    .find(|t| t.name == "default")
-                    .map(|t| t.template),
-            })
-            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+        let chat_template = if matches!(
+            processor_config.processor_class.as_deref(),
+            Some("Llama4Processor")
+                | Some("LlavaNextProcessor")
+                | Some("Idefics2Processor")
+                | Some("Idefics3Processor")
+        ) {
+            None // Do not use chat_template
+        } else {
+            tokenizer_config
+                .chat_template
+                .or(processor_config.chat_template)
+                .and_then(|t| match t {
+                    ChatTemplateVersions::Single(template) => Some(template),
+                    ChatTemplateVersions::Multiple(templates) => templates
+                        .into_iter()
+                        .find(|t| t.name == "default")
+                        .map(|t| t.template),
+                })
+                .map(|t| {
+                    ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+                })
+        };

         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
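
Note on the guard introduced above: `Option::as_deref()` converts the `Option<String>` field into an `Option<&str>`, which `matches!` can then test directly against string-literal patterns. A minimal, self-contained sketch of the same pattern; the helper name `skips_chat_template` and the `CLIPProcessor` counter-example are ours, not part of the diff:

// Standalone sketch mirroring the `matches!` guard from the hunk above.
fn skips_chat_template(processor_class: Option<&str>) -> bool {
    // `matches!` evaluates to true when the value fits any listed pattern.
    matches!(
        processor_class,
        Some("Llama4Processor")
            | Some("LlavaNextProcessor")
            | Some("Idefics2Processor")
            | Some("Idefics3Processor")
    )
}

fn main() {
    let class: Option<String> = Some("Llama4Processor".to_string());
    // `as_deref()` turns Option<String> into Option<&str> for matching.
    assert!(skips_chat_template(class.as_deref()));
    assert!(!skips_chat_template(Some("CLIPProcessor")));
    assert!(!skips_chat_template(None));
}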
@@ -229,6 +241,15 @@ impl Infer {
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
+        if self.chat_template.is_none() {
+            let textmessages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+            let message_str = textmessages
+                .iter()
+                .map(|msg| msg.content.clone()) // Extract content from each `TextMessage`
+                .collect::<Vec<String>>() // Collect all content into a vector
+                .join("\n"); // Join all content into a single string separated by newlines
+            return Ok(message_str);
+        }
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
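
The early return above adds a fallback for models with no chat template: instead of erroring with `TemplateNotFound`, the messages are flattened to plain text. A runnable sketch of that fallback, using a trimmed stand-in for the router's `TextMessage` type (the real type lives in this crate and may carry more fields):

// Trimmed stand-in for the router's `TextMessage`.
#[allow(dead_code)] // `role` is kept to show it is dropped by the fallback
struct TextMessage {
    role: String,
    content: String,
}

// Mirrors the fallback above: keep each message's content, join with newlines.
fn plain_prompt(messages: &[TextMessage]) -> String {
    messages
        .iter()
        .map(|msg| msg.content.clone())
        .collect::<Vec<String>>()
        .join("\n")
}

fn main() {
    let msgs = vec![
        TextMessage { role: "user".to_string(), content: "Hello".to_string() },
        TextMessage { role: "assistant".to_string(), content: "Hi there".to_string() },
    ];
    assert_eq!(plain_prompt(&msgs), "Hello\nHi there");
}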
@@ -210,7 +210,7 @@ pub struct Llama4Processor {
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,
-    pub image_seq_len: usize,
+    pub image_seq_len: Option<usize>,
     pub processor_class: Option<String>,
 }

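Loosening `image_seq_len` to `Option<usize>` changes deserialization behavior: with serde's derive, a missing field of type `Option<T>` defaults to `None`, so a `processor_config.json` that omits the key now parses instead of failing. A small sketch, assuming only the `serde` and `serde_json` crates and a trimmed stand-in struct:

use serde::Deserialize;

// Trimmed stand-in for `HubProcessorConfig`; only the fields relevant here.
#[derive(Debug, Deserialize, Default)]
struct ProcessorConfig {
    image_seq_len: Option<usize>, // missing key now deserializes to None
    processor_class: Option<String>,
}

fn main() {
    // JSON without `image_seq_len`: a plain `usize` field would reject this,
    // while `Option<usize>` parses it and yields None.
    let cfg: ProcessorConfig =
        serde_json::from_str(r#"{ "processor_class": "Idefics3Processor" }"#).unwrap();
    assert_eq!(cfg.image_seq_len, None);
}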
@@ -1008,7 +1008,7 @@ impl ChatRequest {
         Ok((
             GenerateRequest {
                 inputs: inputs.to_string(),
-                add_special_tokens: false,
+                add_special_tokens: infer.chat_template.is_none(),
                 parameters: GenerateParameters {
                     best_of: None,
                     temperature,
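
The flipped `add_special_tokens` flag ties the changes together: a rendered chat template typically already embeds the model's special tokens in its output, while the plain-concatenation fallback produces raw text that still needs them from the tokenizer. A one-line sketch of the rule, with a local `Option` standing in for `infer.chat_template`:

fn main() {
    // Fallback path in this diff: no chat template was loaded.
    let chat_template: Option<String> = None;
    let add_special_tokens = chat_template.is_none();
    assert!(add_special_tokens); // plain-joined prompt: tokenizer adds BOS/EOS
}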