diff --git a/Cargo.lock b/Cargo.lock index 00c7f005..06a61853 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2118,6 +2118,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "metrics" version = "0.23.0" @@ -3112,6 +3121,69 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "pyo3" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.76", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.76", +] + [[package]] name = "qoi" version = "0.4.1" @@ -4068,7 +4140,7 @@ dependencies = [ "pkg-config", "text-generation-router", "thiserror", - "tokenizers", + "tokenizers 0.19.1", "tokio", "tokio-stream", "tracing", @@ -4091,7 +4163,7 @@ dependencies = [ "tabled", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.20.0", "tokio", "tracing", "tracing-subscriber", @@ -4161,6 +4233,7 @@ dependencies = [ "once_cell", "opentelemetry 0.20.0", "opentelemetry-otlp", + "pyo3", "rand", "regex", "reqwest", @@ -4168,7 +4241,7 @@ dependencies = [ "serde_json", "sysinfo", "thiserror", - "tokenizers", + "tokenizers 0.20.0", "tokio", "tokio-stream", "tower-http", @@ -4219,7 +4292,7 @@ dependencies = [ "slotmap", "text-generation-router", "thiserror", - "tokenizers", + "tokenizers 0.20.0", "tokio", "tokio-stream", "tonic 0.10.2", @@ -4374,6 +4447,39 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.4", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.39.3" @@ -4839,6 +4945,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = 
"0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index 79fda15d..a50bba24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] base64 = "0.22.0" -tokenizers = { version = "0.19.1", features = ["http"] } +tokenizers = { version = "0.20.0", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 978a495c..2bbb6753 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -357,6 +357,7 @@ impl State { let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = (block_allocation, &self.block_allocator) { + tracing::debug!("Allocating {tokens} with {input_ids:?}"); match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget diff --git a/router/Cargo.toml b/router/Cargo.toml index 5c328e8a..6a752db6 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [ ] } csv = "1.3.0" ureq = "=2.9" +pyo3 = { version = "0.22.2", features = ["auto-initialize"] } [build-dependencies] diff --git a/router/src/server.rs b/router/src/server.rs index 6a04ab00..8bd49b93 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -41,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; +use pyo3::types::IntoPyDict; use serde_json::Value; use std::convert::Infallible; use std::fs::File; @@ -48,7 +49,6 @@ use 
std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use thiserror::Error; -use tokenizers::processors::template::TemplateProcessing; use tokenizers::Tokenizer; use tokio::select; use tokio::signal; @@ -1860,18 +1860,34 @@ pub async fn run( }); let tokenizer: Option = tokenizer_filename.and_then(|filename| { - let mut tokenizer = Tokenizer::from_file(filename).ok(); - if let Some(tokenizer) = &mut tokenizer { - if let Some(class) = &tokenizer_config.tokenizer_class { - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ - if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { - tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - tokenizer.with_post_processor(post_processor); - } - } - } - } - tokenizer + use pyo3::prelude::*; + let convert = pyo3::Python::with_gil(|py| -> PyResult<()> { + let transformers = py.import_bound("transformers")?; + let auto = transformers.getattr("AutoTokenizer")?; + let from_pretrained = auto.getattr("from_pretrained")?; + let args = (tokenizer_name.to_string(),); + let kwargs = [( + "revision", + revision.clone().unwrap_or_else(|| "main".to_string()), + )] + .into_py_dict_bound(py); + let tokenizer = from_pretrained.call(args, Some(&kwargs))?; + let save = tokenizer.getattr("save_pretrained")?; + let args = ("out".to_string(),); + save.call1(args)?; + Ok(()) + }) + .inspect_err(|err| { + tracing::error!("Failed to import python tokenizer {err}"); + }); + let filename = if convert.is_ok() { + // If we have correctly loaded and resaved with transformers + // We might have modified the tokenizer.json according to transformers + "out/tokenizer.json".into() + } else { + filename + }; + Tokenizer::from_file(filename).ok() }); let config: Option = 
config_filename.and_then(|filename| { @@ -2591,77 +2607,6 @@ pub enum WebServerError { Axum(#[from] axum::BoxError), } -/// Create a post_processor for the LlamaTokenizer -fn create_post_processor( - tokenizer: &Tokenizer, - tokenizer_config: &HubTokenizerConfig, -) -> Result { - let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); - let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); - - let bos_token = tokenizer_config.bos_token.as_ref(); - let eos_token = tokenizer_config.eos_token.as_ref(); - - if add_bos_token && bos_token.is_none() { - panic!("add_bos_token = true but bos_token is None"); - } - - if add_eos_token && eos_token.is_none() { - panic!("add_eos_token = true but eos_token is None"); - } - - let mut single = Vec::new(); - let mut pair = Vec::new(); - let mut special_tokens = Vec::new(); - - if add_bos_token { - if let Some(bos) = bos_token { - let bos_token_id = tokenizer - .token_to_id(bos.as_str()) - .expect("Should have found the bos token id"); - special_tokens.push((bos.as_str(), bos_token_id)); - single.push(format!("{}:0", bos.as_str())); - pair.push(format!("{}:0", bos.as_str())); - } - } - - single.push("$A:0".to_string()); - pair.push("$A:0".to_string()); - - if add_eos_token { - if let Some(eos) = eos_token { - let eos_token_id = tokenizer - .token_to_id(eos.as_str()) - .expect("Should have found the eos token id"); - special_tokens.push((eos.as_str(), eos_token_id)); - single.push(format!("{}:0", eos.as_str())); - pair.push(format!("{}:0", eos.as_str())); - } - } - - if add_bos_token { - if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos.as_str())); - } - } - - pair.push("$B:1".to_string()); - - if add_eos_token { - if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos.as_str())); - } - } - - let post_processor = TemplateProcessing::builder() - .try_single(single)? - .try_pair(pair)? 
- .special_tokens(special_tokens) - .build()?; - - Ok(post_processor) -} - type PreparedInput = (String, Option, bool); fn prepare_chat_input( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 65180499..834056aa 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -272,6 +272,8 @@ class FlashCausalLMBatch(Batch): assert prefix_len > 0 prefix_len -= 1 + # The debug log below is intentionally left commented out because it is costly. + # log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_len]) tokenized_input = tokenized_input[prefix_len:]