feat(backend): fix main.rs retrieving the tokenizer

Morgan Funtowicz 2024-12-03 12:11:17 +01:00
parent 874bc28d6c
commit 16ba2f5a2b
4 changed files with 104 additions and 118 deletions

Cargo.lock (generated)

@@ -2850,20 +2850,6 @@ dependencies = [
"urlencoding",
]
[[package]]
name = "opentelemetry"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
dependencies = [
"futures-core",
"futures-sink",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.13.0"
@@ -2963,24 +2949,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "opentelemetry_sdk"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
dependencies = [
"async-trait",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"once_cell",
"opentelemetry 0.24.0",
"percent-encoding",
"rand",
"thiserror",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -4369,7 +4337,6 @@ dependencies = [
name = "text-generation-backends-trtllm"
version = "2.4.2-dev0"
dependencies = [
"async-stream",
"async-trait",
"clap 4.5.21",
"cmake",
@@ -4377,16 +4344,14 @@ dependencies = [
"cxx-build",
"hashbrown 0.14.5",
"hf-hub",
"log",
"pkg-config",
"pyo3",
"text-generation-router",
"thiserror",
"tokenizers",
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry 0.25.0",
"tracing-subscriber",
]
[[package]]
@@ -5086,24 +5051,6 @@ dependencies = [
"web-time 0.2.4",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry 0.24.0",
"opentelemetry_sdk 0.24.1",
"smallvec",
"tracing",
"tracing-core",
"tracing-log 0.2.0",
"tracing-subscriber",
"web-time 1.1.0",
]
[[package]]
name = "tracing-opentelemetry-instrumentation-sdk"
version = "0.16.0"

backends/trtllm/Cargo.toml

@@ -7,20 +7,21 @@ homepage.workspace = true
[dependencies]
async-trait = "0.1"
async-stream = "0.3"
#async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] }
#log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
thiserror = "1.0.63"
tracing = "0.1"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
#tracing-opentelemetry = "0.25"
#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
pyo3 = { workspace = true }
[build-dependencies]
cmake = "0.1"

backends/trtllm/src/main.rs

@@ -3,14 +3,13 @@ use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};
/// App Configuration
#[derive(Parser, Debug)]
@@ -61,7 +60,7 @@ struct Args {
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
@@ -126,18 +125,18 @@ async fn get_tokenizer(
// Load tokenizer and model info
let (
tokenizer_filename,
_config_filename,
tokenizer_config_filename,
config_filename,
_tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
_model_info
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +145,24 @@ async fn get_tokenizer(
revision.unwrap_or_else(|| "main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
@@ -170,24 +172,55 @@ async fn get_tokenizer(
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
// let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
// {
// HubTokenizerConfig::from_file(filename)
// } else {
// tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
// };
// let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
// tracing::warn!("Could not find tokenizer config locally and no API specified");
// HubTokenizerConfig::default()
// });
let tokenizer: Tokenizer = {
use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
})
.or_else(|err| {
let out = legacy_tokenizer_handle(config_filename.as_ref());
out.ok_or(err)
})
.expect("We cannot load a tokenizer");
let filename = "out/tokenizer.json";
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
Tokenizer::Rust(tok)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
Tokenizer::Python {
tokenizer_name: tokenizer_name.to_string(),
revision: revision.map(|revision| revision.to_string()),
trust_remote_code: false,
}
}
};
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
Some(tokenizer)
}
#[tokio::main]
@@ -258,14 +291,17 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
}
// Create the backend
let tokenizer = get_tokenizer(
match get_tokenizer(
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.await
.expect("Failed to retrieve tokenizer implementation");
.expect("Failed to retrieve tokenizer implementation") {
Tokenizer::Python { .. } => {
Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
}
Tokenizer::Rust(tokenizer) => {
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
let backend = TensorRtLlmBackendV2::new(
tokenizer,
@@ -301,7 +337,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
).await?;
Ok(())
}
}
}
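
Taken together, the main.rs hunks replace the old file-based Tokenizer::from_file path with a resolve-then-downcast flow: py_resolve_tokenizer serializes the tokenizer from the Python side, out/tokenizer.json is then preferred as a native Rust tokenizer, and main() now rejects a Python-only tokenizer outright, since the TRT-LLM backend cannot drive one. Below is a minimal self-contained sketch of that flow; the enum and error type are redefined locally for illustration, and only tokenizers::Tokenizer::from_file is a real API here — the rest mirrors the diff.

use tokenizers::Tokenizer as RustTokenizer;

// Local stand-ins for text_generation_router::Tokenizer and the backend
// error, redefined here so the sketch compiles on its own.
enum Tokenizer {
    Rust(RustTokenizer),
    Python {
        tokenizer_name: String,
        revision: Option<String>,
        trust_remote_code: bool,
    },
}

#[derive(Debug)]
struct TokenizerError(String);

fn resolve(tokenizer_name: &str, revision: Option<&str>) -> Tokenizer {
    // In the diff, py_resolve_tokenizer runs first and, when it succeeds,
    // leaves a serialized tokenizer at out/tokenizer.json.
    match RustTokenizer::from_file("out/tokenizer.json") {
        Ok(tok) => Tokenizer::Rust(tok),
        Err(_) => Tokenizer::Python {
            tokenizer_name: tokenizer_name.to_string(),
            revision: revision.map(String::from),
            trust_remote_code: false,
        },
    }
}

fn require_rust(tokenizer: Tokenizer) -> Result<RustTokenizer, TokenizerError> {
    // main() only proceeds with the Rust variant; the Python variant becomes
    // a TensorRtLlmBackendError::Tokenizer in the real code.
    match tokenizer {
        Tokenizer::Rust(tok) => Ok(tok),
        Tokenizer::Python { .. } => Err(TokenizerError(
            "Failed to retrieve Rust based tokenizer".to_string(),
        )),
    }
}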

router/src/server.rs

@@ -1588,7 +1588,7 @@ pub fn schema() -> ApiDoc {
ApiDoc
}
fn py_resolve_tokenizer(
pub fn py_resolve_tokenizer(
py: pyo3::Python,
tokenizer_name: &str,
revision: Option<&str>,
@@ -1614,7 +1614,7 @@ fn py_resolve_tokenizer(
Ok(())
}
fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
pub fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
// XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
// and state-spaces/mamba-130m
tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
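
These visibility changes are what let the trtllm backend's main.rs (above) call into the router's tokenizer resolution. A minimal sketch of the call pattern from an external crate follows, assuming the pyo3, tracing, and text-generation-router dependencies are wired up; the signatures come from this diff, and passing None for the config path (no local config.json) is just for illustration.

use pyo3::prelude::*;
use text_generation_router::server::{legacy_tokenizer_handle, py_resolve_tokenizer};

fn resolve_or_fallback(tokenizer_name: &str, revision: Option<&str>) {
    // Try the Python-side resolution first (the final `false` is the
    // trust_remote_code flag, as passed in the main.rs hunk), then fall
    // back to the legacy handler for the odd medusa/mamba-style tokenizers.
    Python::with_gil(|py| py_resolve_tokenizer(py, tokenizer_name, revision, false))
        .inspect_err(|err| tracing::error!("Failed to import python tokenizer {err}"))
        .or_else(|err| legacy_tokenizer_handle(None).ok_or(err))
        .expect("We cannot load a tokenizer");
}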