From 16ba2f5a2bddde0b9f9b2cda034f29df9b3b4c3f Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Tue, 3 Dec 2024 12:11:17 +0100
Subject: [PATCH] feat(backend): fix main.rs retrieving the tokenizer

---
 Cargo.lock                  |  55 +------------
 backends/trtllm/Cargo.toml  |   9 ++-
 backends/trtllm/src/main.rs | 154 ++++++++++++++++++++++--------------
 router/src/server.rs        |   4 +-
 4 files changed, 104 insertions(+), 118 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 72f70fdc..263dc566 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2850,20 +2850,6 @@ dependencies = [
  "urlencoding",
 ]
 
-[[package]]
-name = "opentelemetry"
-version = "0.24.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
-dependencies = [
- "futures-core",
- "futures-sink",
- "js-sys",
- "once_cell",
- "pin-project-lite",
- "thiserror",
-]
-
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"
@@ -2963,24 +2949,6 @@ dependencies = [
  "thiserror",
 ]
 
-[[package]]
-name = "opentelemetry_sdk"
-version = "0.24.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-executor",
- "futures-util",
- "glob",
- "once_cell",
- "opentelemetry 0.24.0",
- "percent-encoding",
- "rand",
- "thiserror",
-]
-
 [[package]]
 name = "option-ext"
 version = "0.2.0"
@@ -4369,7 +4337,6 @@ dependencies = [
 name = "text-generation-backends-trtllm"
 version = "2.4.2-dev0"
 dependencies = [
- "async-stream",
  "async-trait",
  "clap 4.5.21",
  "cmake",
@@ -4377,16 +4344,14 @@ dependencies = [
  "cxx",
  "cxx-build",
  "hashbrown 0.14.5",
  "hf-hub",
- "log",
  "pkg-config",
+ "pyo3",
  "text-generation-router",
  "thiserror",
  "tokenizers",
  "tokio",
  "tokio-stream",
  "tracing",
- "tracing-opentelemetry 0.25.0",
- "tracing-subscriber",
 ]
@@ -5086,24 +5051,6 @@ dependencies = [
  "web-time 0.2.4",
 ]
 
-[[package]]
-name = "tracing-opentelemetry"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
-dependencies = [
- "js-sys",
- "once_cell",
- "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.24.1",
- "smallvec",
- "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
- "tracing-subscriber",
- "web-time 1.1.0",
-]
-
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 97ef1a76..5d907109 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -7,20 +7,21 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
-async-stream = "0.3"
+#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
 hashbrown = "0.14"
 hf-hub = { workspace = true }
-log = { version = "0.4", features = [] }
+#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
 tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.15"
 thiserror = "1.0.63"
 tracing = "0.1"
-tracing-opentelemetry = "0.25"
-tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+#tracing-opentelemetry = "0.25"
+#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+pyo3 = { workspace = true }
 
 [build-dependencies]
 cmake = "0.1"
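The manifest change above is the crux of the build: the backend stops carrying its own logging and OpenTelemetry stack (async-stream, log, tracing-opentelemetry, tracing-subscriber) and links pyo3 instead, so tokenizer resolution can fall back to Python. The pyo3 pattern the new main.rs relies on looks roughly like the sketch below; the transformers calls and the helper name are illustrative assumptions, not code from this patch (assuming pyo3's Bound API, 0.21+):

    use pyo3::prelude::*;
    use pyo3::types::PyDict;

    // Hypothetical sketch: materialize a tokenizer through Python so that
    // the Rust side can later load the serialized "out/tokenizer.json".
    fn resolve_via_python(name: &str, revision: Option<&str>) -> PyResult<()> {
        Python::with_gil(|py| {
            let transformers = py.import_bound("transformers")?;
            let kwargs = PyDict::new_bound(py);
            if let Some(rev) = revision {
                kwargs.set_item("revision", rev)?;
            }
            let tokenizer = transformers
                .getattr("AutoTokenizer")?
                .getattr("from_pretrained")?
                .call((name,), Some(&kwargs))?;
            // Save next to the process so main.rs can find tokenizer.json.
            tokenizer.getattr("save_pretrained")?.call1(("out",))?;
            Ok(())
        })
    }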
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 8ab8c533..9c76bafa 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -3,14 +3,13 @@ use std::path::{Path, PathBuf};
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder};
 use hf_hub::{Cache, Repo, RepoType};
-use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig};
+use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
+use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -61,7 +60,7 @@ struct Args {
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
     #[clap(default_value = "on", long, env)]
-    usage_stats: usage_stats::UsageStatsLevel,
+    usage_stats: UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
 }
@@ -126,18 +125,18 @@ async fn get_tokenizer(
     // Load tokenizer and model info
     let (
-        tokenizer_filename,
-        _config_filename,
-        tokenizer_config_filename,
+        config_filename,
+        _tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
+        _model_info
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
+            None
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +145,24 @@ async fn get_tokenizer(
                 tokenizer_name.to_string(),
                 RepoType::Model,
                 revision.unwrap_or_else(|| "main").to_string(),
             ));
 
-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
+            let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
 
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
                 processor_config_filename,
+                model_info,
             )
         }
@@ -170,24 +172,55 @@ async fn get_tokenizer(
         Type::Cache(cache) => {
             let repo = cache.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
                 revision.clone().unwrap_or_else(|| "main").to_string(),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
+                None
             )
         }
     };
 
     // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    {
-        HubTokenizerConfig::from_file(filename)
-    } else {
-        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    // {
+    //     HubTokenizerConfig::from_file(filename)
+    // } else {
+    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    // };
+
+    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
+    //     HubTokenizerConfig::default()
+    // });
+
+    let tokenizer: Tokenizer = {
+        use pyo3::prelude::*;
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
+            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
+            Ok(())
+        })
+        .inspect_err(|err| {
+            tracing::error!("Failed to import python tokenizer {err}");
+        })
+        .or_else(|err| {
+            let out = legacy_tokenizer_handle(config_filename.as_ref());
+            out.ok_or(err)
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
+        } else {
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.to_string(),
+                revision: revision.map(|revision| revision.to_string()),
+                trust_remote_code: false,
+            }
+        }
     };
 
-    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    Some(tokenizer)
 }
 
 #[tokio::main]
 async fn main() -> Result<(), TensorRtLlmBackendError> {
@@ -258,50 +291,55 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    let tokenizer = get_tokenizer(
+    match get_tokenizer(
         &tokenizer_name,
         tokenizer_config_path.as_deref(),
         revision.as_deref(),
     )
     .await
-    .expect("Failed to retrieve tokenizer implementation");
+    .expect("Failed to retrieve tokenizer implementation") {
+        Tokenizer::Python { .. } => {
+            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
+        }
+        Tokenizer::Rust(tokenizer) => {
+            info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+            let backend = TensorRtLlmBackendV2::new(
+                tokenizer,
+                model_id,
+                executor_worker,
+                max_concurrent_requests,
+            )?;
 
-    info!("Successfully retrieved tokenizer {}", &tokenizer_name);
-    let backend = TensorRtLlmBackendV2::new(
-        tokenizer,
-        model_id,
-        executor_worker,
-        max_concurrent_requests,
-    )?;
+            info!("Successfully created backend");
 
-    info!("Successfully created backend");
+            // Run server
+            server::run(
+                backend,
+                max_concurrent_requests,
+                max_best_of,
+                max_stop_sequences,
+                max_top_n_tokens,
+                max_input_tokens,
+                max_total_tokens,
+                validation_workers,
+                auth_token,
+                tokenizer_name,
+                tokenizer_config_path,
+                revision,
+                false,
+                hostname,
+                port,
+                cors_allow_origin,
+                false,
+                None,
+                None,
+                true,
+                max_client_batch_size,
+                usage_stats,
+                payload_limit,
+            ).await?;
+            Ok(())
+        }
+    }
 
-    // Run server
-    server::run(
-        backend,
-        max_concurrent_requests,
-        max_best_of,
-        max_stop_sequences,
-        max_top_n_tokens,
-        max_input_tokens,
-        max_total_tokens,
-        validation_workers,
-        auth_token,
-        tokenizer_name,
-        tokenizer_config_path,
-        revision,
-        false,
-        hostname,
-        port,
-        cors_allow_origin,
-        false,
-        None,
-        None,
-        true,
-        max_client_batch_size,
-        usage_stats,
-        payload_limit,
-    )
-    .await?;
-    Ok(())
 }
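The tail of get_tokenizer encodes the actual policy: py_resolve_tokenizer is expected to leave a serialized fast tokenizer at out/tokenizer.json, the Rust variant is preferred whenever that file loads, and only otherwise does the backend keep a Tokenizer::Python handle, which main() then rejects for TensorRT-LLM. Note the error plumbing too: a PyErr is only swallowed when legacy_tokenizer_handle returns Some, because Option::ok_or re-injects the original error. Condensed into a standalone sketch (the helper name is illustrative; the path and enum variants are the patch's own):

    // Illustrative condensation of the tail of `get_tokenizer`:
    // prefer the serialized fast tokenizer, fall back to Python.
    fn pick_tokenizer(name: &str, revision: Option<&str>) -> Tokenizer {
        match tokenizers::Tokenizer::from_file("out/tokenizer.json") {
            Ok(tok) => Tokenizer::Rust(tok),
            Err(_) => Tokenizer::Python {
                tokenizer_name: name.to_string(),
                revision: revision.map(str::to_string),
                trust_remote_code: false,
            },
        }
    }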
diff --git a/router/src/server.rs b/router/src/server.rs
index f253cb63..f0b30867 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1588,7 +1588,7 @@ pub fn schema() -> ApiDoc {
     ApiDoc
 }
 
-fn py_resolve_tokenizer(
+pub fn py_resolve_tokenizer(
     py: pyo3::Python,
     tokenizer_name: &str,
     revision: Option<&str>,
@@ -1614,7 +1614,7 @@ fn py_resolve_tokenizer(
     Ok(())
 }
 
-fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
+pub fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
     // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
     // and state-spaces/mamba-130m
     tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
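Making py_resolve_tokenizer and legacy_tokenizer_handle public is what lets the TRT-LLM binary reuse the router's resolution logic instead of duplicating it; together with get_hub_model_info they form the import set the new main.rs pulls in. The call shape from the backend crate mirrors the main.rs hunk, roughly (a sketch; py_resolve_tokenizer's error type is assumed to be PyErr here):

    use pyo3::prelude::*;
    use text_generation_router::server::{legacy_tokenizer_handle, py_resolve_tokenizer};

    // Mirror of the fallback chain in `get_tokenizer`: try the Python
    // resolver first, then the legacy config-based handler, and keep
    // the original error if both fail.
    fn resolve(tokenizer_name: &str, revision: Option<&str>) -> PyResult<()> {
        Python::with_gil(|py| py_resolve_tokenizer(py, tokenizer_name, revision, false))
            .or_else(|err| legacy_tokenizer_handle(None).ok_or(err))
    }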