Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)
feat(backend): fix main.rs retrieving the tokenizer

parent 874bc28d6c
commit 16ba2f5a2b
Cargo.lock (generated): 55 changed lines
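In short: instead of fetching tokenizer.json itself and building a tokenizers::Tokenizer directly (via the now-removed get_base_tokenizer path), the backend's main.rs reuses the router's resolution helpers (py_resolve_tokenizer through pyo3, with legacy_tokenizer_handle as a fallback) and works with the router's Tokenizer enum, accepting only the Tokenizer::Rust variant when constructing the backend. pyo3 becomes a direct dependency, while async-stream, log, tracing-opentelemetry and tracing-subscriber are dropped; the Cargo.lock hunks below are the fallout.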
Cargo.lock:

@@ -2850,20 +2850,6 @@ dependencies = [
  "urlencoding",
 ]
 
-[[package]]
-name = "opentelemetry"
-version = "0.24.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
-dependencies = [
- "futures-core",
- "futures-sink",
- "js-sys",
- "once_cell",
- "pin-project-lite",
- "thiserror",
-]
-
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"
@@ -2963,24 +2949,6 @@ dependencies = [
  "thiserror",
 ]
 
-[[package]]
-name = "opentelemetry_sdk"
-version = "0.24.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-executor",
- "futures-util",
- "glob",
- "once_cell",
- "opentelemetry 0.24.0",
- "percent-encoding",
- "rand",
- "thiserror",
-]
-
 [[package]]
 name = "option-ext"
 version = "0.2.0"
@@ -4369,7 +4337,6 @@ dependencies = [
 name = "text-generation-backends-trtllm"
 version = "2.4.2-dev0"
 dependencies = [
- "async-stream",
  "async-trait",
  "clap 4.5.21",
  "cmake",
@@ -4377,16 +4344,14 @@ dependencies = [
  "cxx-build",
  "hashbrown 0.14.5",
  "hf-hub",
- "log",
  "pkg-config",
+ "pyo3",
  "text-generation-router",
  "thiserror",
  "tokenizers",
  "tokio",
  "tokio-stream",
  "tracing",
- "tracing-opentelemetry 0.25.0",
- "tracing-subscriber",
 ]
 
 [[package]]
@@ -5086,24 +5051,6 @@ dependencies = [
  "web-time 0.2.4",
 ]
 
-[[package]]
-name = "tracing-opentelemetry"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
-dependencies = [
- "js-sys",
- "once_cell",
- "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.24.1",
- "smallvec",
- "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
- "tracing-subscriber",
- "web-time 1.1.0",
-]
-
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"
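The lockfile changes are mechanical fallout from the manifest: with tracing-opentelemetry gone from the backend, the now-unused opentelemetry, opentelemetry_sdk and tracing-opentelemetry packages drop out of the dependency graph, and pyo3 joins the text-generation-backends-trtllm entry. The next hunk edits the backend crate's manifest; the file path is not shown in this view, but the text-generation-router path dependency (../../router) marks it as the TensorRT-LLM backend's Cargo.toml.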
@@ -7,20 +7,21 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
-async-stream = "0.3"
+#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
 hashbrown = "0.14"
 hf-hub = { workspace = true }
-log = { version = "0.4", features = [] }
+#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
 tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.15"
 thiserror = "1.0.63"
 tracing = "0.1"
-tracing-opentelemetry = "0.25"
-tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+#tracing-opentelemetry = "0.25"
+#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+pyo3 = { workspace = true }
 
 [build-dependencies]
 cmake = "0.1"
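The retired dependencies are commented out rather than deleted, and pyo3 = { workspace = true } is the only genuinely new entry. The remaining hunks are the main.rs rework named in the commit title, followed by a small visibility change in the router's server module.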
@@ -3,14 +3,13 @@ use std::path::{Path, PathBuf};
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder};
 use hf_hub::{Cache, Repo, RepoType};
-use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig};
+use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
+use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -61,7 +60,7 @@ struct Args {
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
     #[clap(default_value = "on", long, env)]
-    usage_stats: usage_stats::UsageStatsLevel,
+    usage_stats: UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
 }
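The import list previews the new design: tokenizers::Tokenizer and get_base_tokenizer are gone, replaced by the router's Tokenizer enum plus the get_hub_model_info, legacy_tokenizer_handle and py_resolve_tokenizer helpers; the Args field can also use the bare UsageStatsLevel name now that it is imported directly.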
@@ -126,18 +125,18 @@ async fn get_tokenizer(
 
     // Load tokenizer and model info
     let (
-        tokenizer_filename,
-        _config_filename,
-        tokenizer_config_filename,
+        config_filename,
+        _tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
+        _model_info
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
+            None
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +145,24 @@ async fn get_tokenizer(
                 revision.unwrap_or_else(|| "main").to_string(),
             ));
 
-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
 
+            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
                 processor_config_filename,
+                model_info,
             )
         }
         Type::Cache(cache) => {
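get_tokenizer no longer tracks a tokenizer_filename at all. Its tuple now leads with config_filename (which the legacy fallback below needs) and gains a _model_info slot, populated through get_hub_model_info with a logged warning when the Hub is unreachable.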
@@ -170,24 +172,55 @@ async fn get_tokenizer(
                 revision.clone().unwrap_or_else(|| "main").to_string(),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
+                None
             )
         }
     };
 
     // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    {
-        HubTokenizerConfig::from_file(filename)
-    } else {
-        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    };
+    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    // {
+    //     HubTokenizerConfig::from_file(filename)
+    // } else {
+    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    // };
+
+    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
+    //     HubTokenizerConfig::default()
+    // });
+
+    let tokenizer: Tokenizer = {
+        use pyo3::prelude::*;
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
+            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
+            Ok(())
+        })
+        .inspect_err(|err| {
+            tracing::error!("Failed to import python tokenizer {err}");
+        })
+        .or_else(|err| {
+            let out = legacy_tokenizer_handle(config_filename.as_ref());
+            out.ok_or(err)
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
+        } else {
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.to_string(),
+                revision: revision.map(|revision| revision.to_string()),
+                trust_remote_code: false,
+            }
+        }
+    };
 
-    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    Some(tokenizer)
 }
 
 #[tokio::main]
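The resolution chain above reads: run the Python-side resolver under the GIL, log any failure, fall back to the legacy handler, and panic only if both fail; then prefer a native tokenizer if one was serialized to out/tokenizer.json, otherwise defer to the Python variant. Below is a minimal, dependency-free sketch of that control-flow shape; every name in it is a stand-in invented for illustration, not the real TGI API.

// Hypothetical sketch of the fallback shape used in get_tokenizer above:
// try a primary resolver, log on failure, fall back to a legacy handle,
// then branch on whether a serialized tokenizer landed on disk.
use std::path::Path;

#[derive(Debug)]
enum Tokenizer {
    Rust(String),                      // stands in for tokenizers::Tokenizer
    Python { tokenizer_name: String }, // work deferred to a Python runtime
}

// Stand-ins for py_resolve_tokenizer / legacy_tokenizer_handle.
fn primary_resolver(name: &str) -> Result<(), String> {
    Err(format!("no python runtime for {name}"))
}
fn legacy_handle() -> Option<()> {
    Some(()) // pretend the legacy path serialized out/tokenizer.json
}

fn resolve(name: &str) -> Tokenizer {
    primary_resolver(name)
        .inspect_err(|err| eprintln!("primary resolver failed: {err}"))
        .or_else(|err| legacy_handle().ok_or(err))
        .expect("We cannot load a tokenizer");
    let filename = "out/tokenizer.json";
    if Path::new(filename).exists() {
        Tokenizer::Rust(filename.to_string())
    } else {
        Tokenizer::Python { tokenizer_name: name.to_string() }
    }
}

fn main() {
    println!("{:?}", resolve("gpt2"));
}

The or_else(|err| option.ok_or(err)) pairing is the trick that lets an Option-returning fallback participate in a Result pipeline while preserving the original error for the final expect.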
@@ -258,50 +291,55 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    let tokenizer = get_tokenizer(
+    match get_tokenizer(
         &tokenizer_name,
         tokenizer_config_path.as_deref(),
         revision.as_deref(),
     )
     .await
-    .expect("Failed to retrieve tokenizer implementation");
-
-    info!("Successfully retrieved tokenizer {}", &tokenizer_name);
-    let backend = TensorRtLlmBackendV2::new(
-        tokenizer,
-        model_id,
-        executor_worker,
-        max_concurrent_requests,
-    )?;
-
-    info!("Successfully created backend");
-
-    // Run server
-    server::run(
-        backend,
-        max_concurrent_requests,
-        max_best_of,
-        max_stop_sequences,
-        max_top_n_tokens,
-        max_input_tokens,
-        max_total_tokens,
-        validation_workers,
-        auth_token,
-        tokenizer_name,
-        tokenizer_config_path,
-        revision,
-        false,
-        hostname,
-        port,
-        cors_allow_origin,
-        false,
-        None,
-        None,
-        true,
-        max_client_batch_size,
-        usage_stats,
-        payload_limit,
-    )
-    .await?;
-    Ok(())
+    .expect("Failed to retrieve tokenizer implementation") {
+        Tokenizer::Python { .. } => {
+            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
+        }
+        Tokenizer::Rust(tokenizer) => {
+            info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+            let backend = TensorRtLlmBackendV2::new(
+                tokenizer,
+                model_id,
+                executor_worker,
+                max_concurrent_requests,
+            )?;
+
+            info!("Successfully created backend");
+
+            // Run server
+            server::run(
+                backend,
+                max_concurrent_requests,
+                max_best_of,
+                max_stop_sequences,
+                max_top_n_tokens,
+                max_input_tokens,
+                max_total_tokens,
+                validation_workers,
+                auth_token,
+                tokenizer_name,
+                tokenizer_config_path,
+                revision,
+                false,
+                hostname,
+                port,
+                cors_allow_origin,
+                false,
+                None,
+                None,
+                true,
+                max_client_batch_size,
+                usage_stats,
+                payload_limit,
+            ).await?;
+            Ok(())
+        }
+    }
 }
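main() now has to unpack the enum: TensorRtLlmBackendV2 needs an in-process Rust tokenizer, so a Tokenizer::Python result becomes a TensorRtLlmBackendError::Tokenizer error instead of a running server. The last two hunks sit in the router's server module (the schema()/ApiDoc context places them there) and simply make the helpers public so the backend can import them: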
@@ -1588,7 +1588,7 @@ pub fn schema() -> ApiDoc {
     ApiDoc
 }
 
-fn py_resolve_tokenizer(
+pub fn py_resolve_tokenizer(
     py: pyo3::Python,
     tokenizer_name: &str,
     revision: Option<&str>,
@@ -1614,7 +1614,7 @@ fn py_resolve_tokenizer(
     Ok(())
 }
 
-fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
+pub fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
     // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
     // and state-spaces/mamba-130m
     tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");