Fixing odd tokenization self modifications on the Rust side (load and

resave in Python).
This commit is contained in:
Nicolas Patry 2024-09-10 19:11:32 +02:00
parent a4e3e8c608
commit b3760df25f
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
6 changed files with 150 additions and 89 deletions

120
Cargo.lock generated
View File

@ -2118,6 +2118,15 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "metrics" name = "metrics"
version = "0.23.0" version = "0.23.0"
@ -3112,6 +3121,69 @@ dependencies = [
"prost 0.12.6", "prost 0.12.6",
] ]
[[package]]
name = "pyo3"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.76",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.76",
]
[[package]] [[package]]
name = "qoi" name = "qoi"
version = "0.4.1" version = "0.4.1"
@ -4068,7 +4140,7 @@ dependencies = [
"pkg-config", "pkg-config",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.19.1",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tracing", "tracing",
@ -4091,7 +4163,7 @@ dependencies = [
"tabled", "tabled",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.20.0",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -4161,6 +4233,7 @@ dependencies = [
"once_cell", "once_cell",
"opentelemetry 0.20.0", "opentelemetry 0.20.0",
"opentelemetry-otlp", "opentelemetry-otlp",
"pyo3",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -4168,7 +4241,7 @@ dependencies = [
"serde_json", "serde_json",
"sysinfo", "sysinfo",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.20.0",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http", "tower-http",
@ -4219,7 +4292,7 @@ dependencies = [
"slotmap", "slotmap",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.20.0",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tonic 0.10.2", "tonic 0.10.2",
@ -4374,6 +4447,39 @@ dependencies = [
"unicode_categories", "unicode_categories",
] ]
[[package]]
name = "tokenizers"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.39.3" version = "1.39.3"
@ -4839,6 +4945,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.7.1" version = "0.7.1"

View File

@ -25,7 +25,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies] [workspace.dependencies]
base64 = "0.22.0" base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] } tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] } hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" } metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] } metrics-exporter-prometheus = { version = "0.15.1", features = [] }

View File

@ -357,6 +357,7 @@ impl State {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator) (block_allocation, &self.block_allocator)
{ {
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await { match block_allocator.allocate(tokens, input_ids).await {
None => { None => {
// Entry is over budget // Entry is over budget

View File

@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] } ] }
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9" ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[build-dependencies] [build-dependencies]

View File

@ -41,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION; use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::fs::File; use std::fs::File;
@ -48,7 +49,6 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use thiserror::Error; use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
@ -1860,18 +1860,34 @@ pub async fn run(
}); });
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| { let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok(); use pyo3::prelude::*;
if let Some(tokenizer) = &mut tokenizer { let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
if let Some(class) = &tokenizer_config.tokenizer_class { let transformers = py.import_bound("transformers")?;
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ let auto = transformers.getattr("AutoTokenizer")?;
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { let from_pretrained = auto.getattr("from_pretrained")?;
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); let args = (tokenizer_name.to_string(),);
tokenizer.with_post_processor(post_processor); let kwargs = [(
} "revision",
} revision.clone().unwrap_or_else(|| "main".to_string()),
} )]
} .into_py_dict_bound(py);
tokenizer let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// If we have correctly loaded and resaved with transformers
// We might have modified the tokenizer.json according to transformers
"out/tokenizer.json".into()
} else {
filename
};
Tokenizer::from_file(filename).ok()
}); });
let config: Option<Config> = config_filename.and_then(|filename| { let config: Option<Config> = config_filename.and_then(|filename| {
@ -2591,77 +2607,6 @@ pub enum WebServerError {
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
} }
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
type PreparedInput = (String, Option<GrammarType>, bool); type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input( fn prepare_chat_input(

View File

@ -272,6 +272,8 @@ class FlashCausalLMBatch(Batch):
assert prefix_len > 0 assert prefix_len > 0
prefix_len -= 1 prefix_len -= 1
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_len]) prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:] tokenized_input = tokenized_input[prefix_len:]