feat(backend): fix main.rs retrieving the tokenizer

Morgan Funtowicz 2024-12-03 12:11:17 +01:00
parent 874bc28d6c
commit 16ba2f5a2b
4 changed files with 104 additions and 118 deletions
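For orientation, a condensed sketch of the decision the new get_tokenizer path makes once the Python-side resolution has run. The helper name pick_tokenizer is hypothetical; the out/tokenizer.json path, the Tokenizer variants, and the trust_remote_code flag are taken from the main.rs hunks below, with error handling simplified.

    // Hypothetical helper, condensed from the main.rs changes below: prefer the
    // serialized tokenizer.json as a native Rust tokenizer, otherwise fall back
    // to a Python-backed tokenizer description.
    use text_generation_router::Tokenizer;

    fn pick_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Tokenizer {
        match tokenizers::Tokenizer::from_file("out/tokenizer.json") {
            Ok(tok) => Tokenizer::Rust(tok),
            Err(_) => Tokenizer::Python {
                tokenizer_name: tokenizer_name.to_string(),
                revision: revision.map(str::to_string),
                trust_remote_code: false,
            },
        }
    }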

Cargo.lock (generated)

@@ -2850,20 +2850,6 @@ dependencies = [
  "urlencoding",
 ]
-[[package]]
-name = "opentelemetry"
-version = "0.24.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
-dependencies = [
- "futures-core",
- "futures-sink",
- "js-sys",
- "once_cell",
- "pin-project-lite",
- "thiserror",
-]
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"
@@ -2963,24 +2949,6 @@ dependencies = [
  "thiserror",
 ]
-[[package]]
-name = "opentelemetry_sdk"
-version = "0.24.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-executor",
- "futures-util",
- "glob",
- "once_cell",
- "opentelemetry 0.24.0",
- "percent-encoding",
- "rand",
- "thiserror",
-]
 [[package]]
 name = "option-ext"
 version = "0.2.0"
@@ -4369,7 +4337,6 @@ dependencies = [
 name = "text-generation-backends-trtllm"
 version = "2.4.2-dev0"
 dependencies = [
- "async-stream",
  "async-trait",
  "clap 4.5.21",
  "cmake",
@@ -4377,16 +4344,14 @@
  "cxx-build",
  "hashbrown 0.14.5",
  "hf-hub",
- "log",
  "pkg-config",
+ "pyo3",
  "text-generation-router",
  "thiserror",
  "tokenizers",
  "tokio",
  "tokio-stream",
  "tracing",
- "tracing-opentelemetry 0.25.0",
- "tracing-subscriber",
 ]
 [[package]]
@@ -5086,24 +5051,6 @@ dependencies = [
  "web-time 0.2.4",
 ]
-[[package]]
-name = "tracing-opentelemetry"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
-dependencies = [
- "js-sys",
- "once_cell",
- "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.24.1",
- "smallvec",
- "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
- "tracing-subscriber",
- "web-time 1.1.0",
-]
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"

backends/trtllm/Cargo.toml

@@ -7,20 +7,21 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1"
-async-stream = "0.3"
+#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
 hashbrown = "0.14"
 hf-hub = { workspace = true }
-log = { version = "0.4", features = [] }
+#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
 tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.15"
 thiserror = "1.0.63"
 tracing = "0.1"
-tracing-opentelemetry = "0.25"
-tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+#tracing-opentelemetry = "0.25"
+#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+pyo3 = { workspace = true }
 [build-dependencies]
 cmake = "0.1"

backends/trtllm/src/main.rs

@@ -3,14 +3,13 @@ use std::path::{Path, PathBuf};
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder};
 use hf_hub::{Cache, Repo, RepoType};
-use tokenizers::Tokenizer;
 use tracing::info;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig};
+use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
+use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -61,7 +60,7 @@ struct Args {
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
     #[clap(default_value = "on", long, env)]
-    usage_stats: usage_stats::UsageStatsLevel,
+    usage_stats: UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
 }
@@ -126,18 +125,18 @@ async fn get_tokenizer(
     // Load tokenizer and model info
     let (
-        tokenizer_filename,
-        _config_filename,
-        tokenizer_config_filename,
+        config_filename,
+        _tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
+        _model_info
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
+            None
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +145,24 @@
                 revision.unwrap_or_else(|| "main").to_string(),
             ));
-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
                 processor_config_filename,
+                model_info,
             )
         }
         Type::Cache(cache) => {
@@ -170,24 +172,55 @@
                 revision.clone().unwrap_or_else(|| "main").to_string(),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
+                None
             )
         }
     };
     // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    {
-        HubTokenizerConfig::from_file(filename)
-    } else {
-        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    };
+    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    // {
+    //     HubTokenizerConfig::from_file(filename)
+    // } else {
+    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    // };
+    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
+    //     HubTokenizerConfig::default()
+    // });
+    let tokenizer: Tokenizer = {
+        use pyo3::prelude::*;
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
+            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
+            Ok(())
+        })
+        .inspect_err(|err| {
+            tracing::error!("Failed to import python tokenizer {err}");
+        })
+        .or_else(|err| {
+            let out = legacy_tokenizer_handle(config_filename.as_ref());
+            out.ok_or(err)
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
+        } else {
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.to_string(),
+                revision: revision.map(|revision| revision.to_string()),
+                trust_remote_code: false,
+            }
+        }
+    };
-    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    Some(tokenizer)
 }
 #[tokio::main]
@@ -258,50 +291,55 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
     // Create the backend
-    let tokenizer = get_tokenizer(
+    match get_tokenizer(
         &tokenizer_name,
         tokenizer_config_path.as_deref(),
         revision.as_deref(),
     )
     .await
-    .expect("Failed to retrieve tokenizer implementation");
-    info!("Successfully retrieved tokenizer {}", &tokenizer_name);
-    let backend = TensorRtLlmBackendV2::new(
-        tokenizer,
-        model_id,
-        executor_worker,
-        max_concurrent_requests,
-    )?;
-    info!("Successfully created backend");
-    // Run server
-    server::run(
-        backend,
-        max_concurrent_requests,
-        max_best_of,
-        max_stop_sequences,
-        max_top_n_tokens,
-        max_input_tokens,
-        max_total_tokens,
-        validation_workers,
-        auth_token,
-        tokenizer_name,
-        tokenizer_config_path,
-        revision,
-        false,
-        hostname,
-        port,
-        cors_allow_origin,
-        false,
-        None,
-        None,
-        true,
-        max_client_batch_size,
-        usage_stats,
-        payload_limit,
-    )
-    .await?;
-    Ok(())
+    .expect("Failed to retrieve tokenizer implementation") {
+        Tokenizer::Python { .. } => {
+            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
+        }
+        Tokenizer::Rust(tokenizer) => {
+            info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+            let backend = TensorRtLlmBackendV2::new(
+                tokenizer,
+                model_id,
+                executor_worker,
+                max_concurrent_requests,
+            )?;
+            info!("Successfully created backend");
+            // Run server
+            server::run(
+                backend,
+                max_concurrent_requests,
+                max_best_of,
+                max_stop_sequences,
+                max_top_n_tokens,
+                max_input_tokens,
+                max_total_tokens,
+                validation_workers,
+                auth_token,
+                tokenizer_name,
+                tokenizer_config_path,
+                revision,
+                false,
+                hostname,
+                port,
+                cors_allow_origin,
+                false,
+                None,
+                None,
+                true,
+                max_client_batch_size,
+                usage_stats,
+                payload_limit,
+            ).await?;
+            Ok(())
+        }
+    }
 }
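For reference, the Tokenizer value matched on above is re-exported by text-generation-router; the sketch below shows the enum shape implied by the match arms in this diff, not the crate's literal definition.

    // Shape of text_generation_router::Tokenizer as used above (sketch only;
    // the real definition lives in the router crate).
    pub enum Tokenizer {
        // Native tokenizer loaded from a tokenizer.json via the `tokenizers` crate.
        Rust(tokenizers::Tokenizer),
        // Tokenizer that has to be driven from Python.
        Python {
            tokenizer_name: String,
            revision: Option<String>,
            trust_remote_code: bool,
        },
    }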

router/src/server.rs

@@ -1588,7 +1588,7 @@ pub fn schema() -> ApiDoc {
     ApiDoc
 }
-fn py_resolve_tokenizer(
+pub fn py_resolve_tokenizer(
     py: pyo3::Python,
     tokenizer_name: &str,
     revision: Option<&str>,
@@ -1614,7 +1614,7 @@ fn py_resolve_tokenizer(
     Ok(())
 }
-fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
+pub fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
     // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
     // and state-spaces/mamba-130m
     tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");