//! Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{
    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
    TokenizerTrait,
};
use crate::{PyTokenizer, Tokenizer};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::cmp::min;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

static DEFAULT_GENERATION_LENGTH: u32 = 1024;

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
    /// Validation parameters
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    sender: mpsc::UnboundedSender<TokenizerRequest>,
}

impl Validation {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        workers: usize,
        tokenizer: Tokenizer,
        config: Option<Config>,
        preprocessor_config: Option<HubPreprocessorConfig>,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
        disable_grammar_support: bool,
    ) -> Self {
        let workers = if let Tokenizer::Python { .. } = &tokenizer {
            1
        } else {
            workers
        };

        // If we have a fast tokenizer
        let sender = {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);

            // Create workers
            for _ in 0..workers {
                let tokenizer_clone = tokenizer.clone();
                let config_clone = config.clone();
                let preprocessor_config_clone = preprocessor_config.clone();
                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                senders.push(tokenizer_sender);

                // Spawn worker
                tokio::task::spawn_blocking(move || {
                    tokenizer_worker(
                        tokenizer_clone,
                        config_clone,
                        preprocessor_config_clone,
                        tokenizer_receiver,
                    )
                });
            }

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            validation_sender
        };

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
        }
    }

    #[instrument(skip(self, inputs))]
    pub async fn tokenize(
        &self,
        inputs: String,
        add_special_tokens: bool,
        truncate: Option<usize>,
    ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
        // If we have a fast tokenizer
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send request to the background validation task
        // Unwrap is safe here
        let _ = &self
            .sender
            .send((
                (inputs, add_special_tokens, truncate),
                response_sender,
                Span::current(),
            ))
            .unwrap();

        // Await on response channel
        // Unwrap is safe here
        let encoding = response_receiver.await.unwrap()?;

        Ok(encoding)
    }

    #[allow(clippy::type_complexity)]
    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        add_special_tokens: bool,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
        // If we have a fast tokenizer
        let (encoding, inputs) = self
            .tokenize(inputs.clone(), add_special_tokens, truncate)
            .await?;
        // Create response channel
        let input_length = if let Some(truncate) = truncate {
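            // Cap the reported input length at the requested truncation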
            std::cmp::min(encoding.len(), truncate)
        } else {
            encoding.len()
        };

        // Get total tokens
        let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
            (max_new_tokens, max_new_tokens)
        } else {
            // Use the maximum possible number of tokens as default
            // However, the system will re-queue the request every time it completes
            // `DEFAULT_GENERATION_LENGTH` tokens.
            let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
            (
                min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
                max_new_tokens,
            )
        };
        let total_tokens = input_length + max_new_tokens as usize;

        // Validate MaxTotalTokens
        if total_tokens > self.max_total_tokens {
            return Err(ValidationError::MaxTotalTokens(
                self.max_total_tokens,
                input_length,
                max_new_tokens,
            ));
        }

        // Validate InputLength
        if input_length > self.max_input_length {
            return Err(ValidationError::InputLength(
                self.max_input_length,
                input_length,
            ));
        }

        let ids = encoding.get_ids();
        let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();

        metrics::histogram!("tgi_request_input_length").record(input_length as f64);
        Ok((
            inputs,
            Some(input_ids),
            input_length,
            max_new_tokens,
            max_total_new_tokens,
        ))
    }

    /// Validate a payload and get the number of tokens in the input
    #[instrument(skip_all)]
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            decoder_input_details,
            top_n_tokens,
            grammar,
            adapter_id,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
        if !(-2.0..=2.0).contains(&frequency_penalty) {
            return Err(ValidationError::FrequencyPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == Some(0) {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        let top_n_tokens = top_n_tokens
            .map(|value| {
                if value > self.max_top_n_tokens {
                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
                }
                Ok(value)
            })
            .unwrap_or(Ok(0))?;

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
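            // A truncation of 0 would drop the entire prompt, and anything above
            // `max_input_length` exceeds what the model accepts, so both are rejected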
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
            .validate_input(
                request.inputs,
                request.add_special_tokens,
                truncate,
                max_new_tokens,
            )
            .await?;

        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
        // NOTE: this is currently difficult because we need the tokenizer in Python to build
        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
        let grammar = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
                let valid_grammar = match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure it's
                            // valid JSON
                            Value::String(s) => serde_json::from_str(&s)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
                            Value::Object(_) => Ok(json),
                            _ => Err(ValidationError::Grammar),
                        }?;

                        // Check if the json is a valid JSONSchema
                        JSONSchema::options()
                            .with_draft(Draft::Draft202012)
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                        // The schema can be valid but lack properties.
                        // We need properties for the grammar to be successfully parsed in Python.
                        // Therefore, we must check and throw an error if properties are missing.
                        json.get("properties")
                            .ok_or(ValidationError::InvalidGrammar(
                                "Grammar must have a 'properties' field".to_string(),
                            ))?;

                        // Do compilation in the router for performance. In the future, we
                        // should also move regex -> automaton compilation in the router,
                        // but this is not yet supported in pure Rust by outlines-core.
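                        // As a rough illustration (not the exact emitted pattern): a schema
                        // such as {"type": "object", "properties": {"name": {"type": "string"}}}
                        // compiles to a regex matching only JSON objects whose "name" field
                        // is a string, which the backend then uses for constrained decoding.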
                        let grammar_regex = json_schema_to_regex(&json, None, &json)
                            .map_err(ValidationError::RegexFromSchema)?;
                        ValidGrammar::Regex(grammar_regex.to_string())
                    }
                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
                };
                Some(valid_grammar)
            }
            None => None,
        };

        let parameters = ValidParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
            grammar,
        };
        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
            max_total_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            input_ids: input_ids.map(Arc::new),
            add_special_tokens: request.add_special_tokens,
            decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
            top_n_tokens,
            adapter_id,
        })
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

/// Round robin tokenization task
async fn round_robin_task(
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

/// Start tokenization workers
fn tokenizer_worker(
    tokenizer: Tokenizer,
    config: Option<Config>,
    preprocessor_config: Option<HubPreprocessorConfig>,
    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
    match tokenizer {
        Tokenizer::Python {
            tokenizer_name,
            revision,
            trust_remote_code,
        } => {
            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
                let tokenizer =
                    PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
                // Loop over requests
                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
                    receiver.blocking_recv()
                {
                    parent_span.in_scope(|| {
                        response_tx
                            .send(prepare_input(
                                inputs,
                                truncate,
                                add_special_tokens,
                                &tokenizer,
                                config.as_ref(),
                                preprocessor_config.as_ref(),
                            ))
                            .unwrap_or(())
                    })
                }
                Ok(())
            })
            .expect("Failure in python tokenizer worker");
        }
        Tokenizer::Rust(tokenizer) => {
            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
                receiver.blocking_recv()
            {
                parent_span.in_scope(|| {
                    response_tx
                        .send(prepare_input(
                            inputs,
                            truncate,
                            add_special_tokens,
                            &tokenizer,
                            config.as_ref(),
                            preprocessor_config.as_ref(),
                        ))
                        .unwrap_or(())
                })
            }
        }
    }
}

fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
    match mimetype {
        "image/png" => Some(ImageFormat::Png),
        "image/jpeg" => Some(ImageFormat::Jpeg),
        "image/jpg" => Some(ImageFormat::Jpeg),
        "image/gif" => Some(ImageFormat::Gif),
        "image/webp" => Some(ImageFormat::WebP),
        "image/tiff" => Some(ImageFormat::Tiff),
        // "image/pnm" => Some(ImageFormat::Pnm),
        // "image/tga" => Some(ImageFormat::Tga),
        // "image/dds" => Some(ImageFormat::Dds),
        // "image/bmp" => Some(ImageFormat::Bmp),
        // "image/ico" => Some(ImageFormat::Ico),
        // "image/x-exr" => Some(ImageFormat::OpenExr),
        _ => None,
    }
}

fn format_to_mimetype(format: ImageFormat) -> String {
    match format {
        ImageFormat::Png => "image/png",
        ImageFormat::Jpeg => "image/jpeg",
        ImageFormat::Gif => "image/gif",
        ImageFormat::WebP => "image/webp",
        ImageFormat::Tiff => "image/tiff",
        _ => "application/octet-stream",
    }
    .to_string()
}

fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
    if
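    // Markdown-style image reference with a remote URL, e.g. ![](https://...)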
input.starts_with("![](http://") || input.starts_with("![](https://") { let url = &input["![](".len()..input.len() - 1]; let data = reqwest::blocking::get(url)?.bytes()?; let format = image::guess_format(&data)?; // TODO Remove this clone let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?; let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; let mimetype = format_to_mimetype(format); Ok((data.to_vec(), mimetype, height, width)) } else if input.starts_with("![](data:") { // Remove ![](....) let content = &input["![](data:".len()..input.len() - 1]; let tokens: Vec<_> = content.split(';').collect(); if tokens.len() != 2 { return Err(ValidationError::InvalidImageContent(content.to_string())); } let mimetype = tokens[0]; let content = tokens[1]; if !content.starts_with("base64,") { return Err(ValidationError::InvalidImageContent(content.to_string())); } let data = STANDARD.decode(content["base64,".len()..].as_bytes())?; let img = if let Some(format) = format_from_mimetype(mimetype) { ImageReader::with_format(Cursor::new(&data), format).decode()? } else { ImageReader::new(Cursor::new(&data)) .with_guessed_format() .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))? .decode()? }; let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; Ok((data, mimetype.to_string(), height, width)) } else { Err(ValidationError::InvalidImageContent(input.to_string())) } } fn fetch_video(input: &str) -> Result { if input.starts_with("http://") || input.starts_with("https://") { Ok(input.to_string()) } else { Err(ValidationError::InvalidVideoContent(input.to_string())) } } fn image_tokens( config: &Config, preprocessor_config: Option<&HubPreprocessorConfig>, height: usize, width: usize, ) -> String { use Config::*; use HubPreprocessorConfig::*; match config { Idefics => "".to_string(), Mllama => "<|image|>".to_string(), Idefics2(config) => { const FAKE: &str = ""; const IMAGE: &str = ""; let slots = config.get_number_of_features(height, width); let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len()); image_string.push_str(FAKE); image_string.extend(iter::repeat(IMAGE).take(slots)); image_string.push_str(FAKE); if matches!( preprocessor_config, Some(Idefics2Processor(Idefics2Preprocessor { do_image_splitting: true, .. })) ) { image_string = image_string.repeat(5); }; image_string } Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), Qwen2Vl(config) => format!( "<|vision_start|>{:?}<|vision_end|>", "<|image_pad|>".repeat(config.get_number_of_features(height, width)) ), _ => unimplemented!("Images tokens are not supported for this model configuration"), } } fn image_tokens_fixup(config: &Config, text: String) -> String { match config { Config::Idefics2(_) => { const FAKE: &str = ""; text.replace(&format!("{FAKE}{FAKE}"), FAKE) } _ => text, } } /// Get input length and optionally truncate it fn prepare_input( inputs: String, _truncate: Option, add_special_tokens: bool, tokenizer: &T, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, ) -> Result<(tokenizers::Encoding, Vec), ValidationError> { use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); // Add video regex static VIDEO_RE: Lazy = Lazy::new(|| Regex::new(r"