diff --git a/router/src/lib.rs b/router/src/lib.rs
index 3132d9b6..ccf8e8d4 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -31,70 +31,72 @@ pub enum Tokenizer {
     Rust(tokenizers::Tokenizer),
 }
 
-impl Tokenizer {
-    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
-        match self {
-            Self::Python {
-                tokenizer_name,
-                revision,
-            } => {
-                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
-                    let transformers = py.import_bound("transformers")?;
-                    let auto = transformers.getattr("AutoTokenizer")?;
-                    let from_pretrained = auto.getattr("from_pretrained")?;
-                    let args = (tokenizer_name.to_string(),);
-                    let kwargs = if let Some(rev) = &revision {
-                        [("revision", rev.to_string())].into_py_dict_bound(py)
-                    } else {
-                        pyo3::types::PyDict::new_bound(py)
-                    };
-                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-                    Ok(tokenizer)
-                }()
-                .expect("Cannot load the tokenizer");
-                tracing::info!("Loaded a python tokenizer");
-                OwnedTokenizer::Python(pytok)
-            }
-            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
-        }
+pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);
+
+impl<'a> PyTokenizer<'a> {
+    fn from_py(
+        py: Python<'a>,
+        tokenizer_name: String,
+        revision: Option<String>,
+    ) -> PyResult<PyTokenizer<'a>> {
+        let transformers = py.import_bound("transformers")?;
+        let auto = transformers.getattr("AutoTokenizer")?;
+        let from_pretrained = auto.getattr("from_pretrained")?;
+        let args = (tokenizer_name,);
+        let kwargs = if let Some(rev) = &revision {
+            [("revision", rev.to_string())].into_py_dict_bound(py)
+        } else {
+            pyo3::types::PyDict::new_bound(py)
+        };
+        let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+        tracing::info!("Loaded a python tokenizer");
+        Ok(PyTokenizer(tokenizer))
     }
 }
 
-pub enum OwnedTokenizer<'a> {
-    Python(pyo3::Bound<'a, pyo3::PyAny>),
-    Rust(tokenizers::Tokenizer),
+trait TokenizerTrait {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
 }
 
-impl<'a> OwnedTokenizer<'a> {
-    fn encode(
+impl TokenizerTrait for tokenizers::Tokenizer {
+    fn encode_trait(
         &self,
         query: String,
         add_special_tokens: bool,
     ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
-        match self {
-            Self::Python(pytok) => {
-                let py = pytok.py();
-                let kwargs = [
-                    ("text", query.into_py(py)),
-                    ("add_special_tokens", add_special_tokens.into_py(py)),
-                ]
-                .into_py_dict_bound(py);
-                let encode = pytok.getattr("encode")?;
-                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
-                Ok(Encoding::new(
-                    input_ids,
-                    vec![],                           // type ids
-                    vec![],                           // tokens (strings)
-                    vec![],                           // words
-                    vec![],                           // offsets
-                    vec![],                           // special_tokens_mask
-                    vec![],                           // attention_mask
-                    vec![],                           // overflowing
-                    std::collections::HashMap::new(), //sequence_ranges
-                ))
-            }
-            Self::Rust(tok) => tok.encode(query, add_special_tokens),
-        }
+        self.encode(query, add_special_tokens)
+    }
+}
+
+impl<'a> TokenizerTrait for PyTokenizer<'a> {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        let py = self.0.py();
+        let kwargs = [
+            ("text", query.into_py(py)),
+            ("add_special_tokens", add_special_tokens.into_py(py)),
+        ]
+        .into_py_dict_bound(py);
+        let encode = self.0.getattr("encode")?;
+        let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+        Ok(Encoding::new(
+            input_ids,
+            vec![],                           // type ids
+            vec![],                           // tokens (strings)
+            vec![],                           // words
+            vec![],                           // offsets
+            vec![],                           // special_tokens_mask
+            vec![],                           // attention_mask
+            vec![],                           // overflowing
+            std::collections::HashMap::new(), //sequence_ranges
+        ))
     }
 }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index a02d53a5..8159ede4 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -3,8 +3,9 @@ use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+    TokenizerTrait,
 };
-use crate::{OwnedTokenizer, Tokenizer};
+use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -434,26 +435,53 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    pyo3::Python::with_gil(|py| {
-        let tokenizer = tokenizer.into_owned(py);
-        // Loop over requests
-        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-            receiver.blocking_recv()
-        {
-            parent_span.in_scope(|| {
-                response_tx
-                    .send(prepare_input(
-                        inputs,
-                        truncate,
-                        add_special_tokens,
-                        &tokenizer,
-                        config.as_ref(),
-                        preprocessor_config.as_ref(),
-                    ))
-                    .unwrap_or(())
+    match tokenizer {
+        Tokenizer::Python {
+            tokenizer_name,
+            revision,
+        } => {
+            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
+                let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
+                // Loop over requests
+                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                    receiver.blocking_recv()
+                {
+                    parent_span.in_scope(|| {
+                        response_tx
+                            .send(prepare_input(
+                                inputs,
+                                truncate,
+                                add_special_tokens,
+                                &tokenizer,
+                                config.as_ref(),
+                                preprocessor_config.as_ref(),
+                            ))
+                            .unwrap_or(())
+                    })
+                }
+                Ok(())
             })
+            .expect("Failure in python tokenizer worker");
         }
-    });
+        Tokenizer::Rust(tokenizer) => {
+            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                receiver.blocking_recv()
+            {
+                parent_span.in_scope(|| {
+                    response_tx
+                        .send(prepare_input(
+                            inputs,
+                            truncate,
+                            add_special_tokens,
+                            &tokenizer,
+                            config.as_ref(),
+                            preprocessor_config.as_ref(),
+                        ))
+                        .unwrap_or(())
+                })
+            }
+        }
+    }
 }
 
 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
@@ -581,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 }
 
 /// Get input length and optionally truncate it
-fn prepare_input(
+fn prepare_input<T: TokenizerTrait>(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &OwnedTokenizer,
+    tokenizer: &T,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
@@ -622,7 +650,7 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, add_special_tokens)
+        .encode_trait(tokenizer_query, add_special_tokens)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
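
As a quick illustration of the pattern this diff moves to (not part of the change itself): instead of matching on a `Python`/`Rust` enum inside every `encode` call, `prepare_input` becomes generic over `TokenizerTrait`, and `tokenizer_worker` picks the concrete backend once up front. The sketch below shows that shape with stand-in types; `StubRustTokenizer`, `StubPyTokenizer`, and the byte-based "encoding" are illustrative placeholders, not the real `tokenizers`/pyo3 code.

```rust
// Minimal sketch of the trait-based dispatch used above, with stand-in types.
// The real trait returns `tokenizers::Encoding`; a Vec<u32> keeps this self-contained.
trait TokenizerTrait {
    fn encode_trait(&self, query: String, add_special_tokens: bool) -> Vec<u32>;
}

struct StubRustTokenizer; // stands in for tokenizers::Tokenizer
struct StubPyTokenizer;   // stands in for PyTokenizer<'a> (a GIL-bound transformers tokenizer)

impl TokenizerTrait for StubRustTokenizer {
    fn encode_trait(&self, query: String, add_special_tokens: bool) -> Vec<u32> {
        // Real impl simply forwards: self.encode(query, add_special_tokens).
        let mut ids: Vec<u32> = query.bytes().map(u32::from).collect();
        if add_special_tokens {
            ids.insert(0, 0); // pretend BOS token
        }
        ids
    }
}

impl TokenizerTrait for StubPyTokenizer {
    fn encode_trait(&self, query: String, add_special_tokens: bool) -> Vec<u32> {
        // Real impl calls the Python tokenizer's `encode` through pyo3 and extracts Vec<u32>.
        let mut ids: Vec<u32> = query.bytes().rev().map(u32::from).collect();
        if add_special_tokens {
            ids.push(1); // pretend EOS token
        }
        ids
    }
}

// Mirrors prepare_input<T: TokenizerTrait>: generic over the backend, monomorphized per arm.
fn prepare_input<T: TokenizerTrait>(tokenizer: &T, inputs: String) -> Vec<u32> {
    tokenizer.encode_trait(inputs, true)
}

fn main() {
    // The worker matches the Tokenizer enum once, then both arms call the same generic helper.
    println!("{:?}", prepare_input(&StubRustTokenizer, "hello".into()));
    println!("{:?}", prepare_input(&StubPyTokenizer, "hello".into()));
}
```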