diff --git a/router/src/infer.rs b/router/src/infer.rs index eef42989..1447e756 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,7 +2,8 @@ use crate::validation::{Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token, + HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, + TextMessage, Token, }; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; @@ -67,6 +68,7 @@ impl Infer { speculate: u32, generation_health: Arc, tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, ) -> Self { // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); @@ -89,6 +91,7 @@ impl Infer { let chat_template = tokenizer_config .chat_template + .or(processor_config.chat_template) .and_then(|t| match t { ChatTemplateVersions::Single(template) => Some(template), ChatTemplateVersions::Multiple(templates) => templates @@ -98,7 +101,10 @@ impl Infer { }) .map(|t| { // .strip() is not supported in minijinja - let t = t.replace(".strip()", " | trim"); + // .capitalize() is not supported in minijinja but we can use | capitalize + let t = t + .replace(".strip()", " | trim") + .replace(".capitalize()", " | capitalize"); ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) }); diff --git a/router/src/lib.rs b/router/src/lib.rs index ba1d9acc..9b3283df 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -80,6 +80,20 @@ impl HubTokenizerConfig { } } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct HubProcessorConfig { + pub chat_template: Option, + pub image_seq_len: usize, + pub processor_class: Option, +} + +impl HubProcessorConfig { + pub fn from_file>(filename: P) -> Option { + let content = std::fs::read_to_string(filename).ok()?; + serde_json::from_str(&content).ok() + } +} + #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { diff --git a/router/src/main.rs b/router/src/main.rs index b11c4526..b526367c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,7 +14,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::config::Config; -use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; +use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; use thiserror::Error; use tokenizers::Tokenizer; use tower_http::cors::AllowOrigin; @@ -206,11 +206,18 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer and model info - let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) = match api { Type::None => ( Some(local_path.join("tokenizer.json")), Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("processor_config.json")), None, ), Type::Api(api) => { @@ -226,6 +233,7 @@ async fn main() -> Result<(), RouterError> { }; let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); let model_info = if let Some(model_info) = get_model_info(&api_repo).await { Some(model_info) @@ -237,6 +245,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_filename, config_filename, tokenizer_config_filename, + processor_config_filename, model_info, ) } @@ -250,6 +259,7 @@ async fn main() -> Result<(), RouterError> { repo.get("tokenizer.json"), repo.get("config.json"), repo.get("tokenizer_config.json"), + repo.get("processor_config.json"), None, ) } @@ -286,6 +296,10 @@ async fn main() -> Result<(), RouterError> { HubTokenizerConfig::default() }); + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + tracing::info!("Using config {config:?}"); if tokenizer.is_none() { tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); @@ -397,6 +411,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + processor_config, messages_api_enabled, disable_grammar_support, max_client_batch_size, diff --git a/router/src/server.rs b/router/src/server.rs index e7570ded..64ec31eb 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,9 +5,9 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, - PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, - Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer, + Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1395,6 +1395,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, @@ -1495,6 +1496,7 @@ pub async fn run( shard_info.speculate, generation_health, tokenizer_config, + processor_config, ); // Duration buckets