fix: follow base model for tokenizer in router (#1424)

Close #1422
OlivierDehaene 2024-01-10 16:35:54 +01:00 committed by Karol Damaszke
parent 92ddb41d95
commit af63e3273f
3 changed files with 189 additions and 163 deletions
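
In short: when the requested Hub repo has no tokenizer.json of its own but its config.json names a base model, the router now falls back to that base model's tokenizer instead of silently disabling Rust-side input length validation. Below is a condensed sketch of the new lookup order, using the hf-hub tokio API that this diff switches to. The helper name resolve_tokenizer and the flattened error handling are illustrative only, not part of the commit; it assumes hf-hub 0.3 with the "tokio" feature (as added in the Cargo.toml hunk), plus the tokenizers and serde_json crates.

use hf_hub::api::tokio::{Api, ApiRepo};
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

async fn resolve_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
    // 1. Prefer the repo's own fast tokenizer.
    if let Ok(path) = api_repo.get("tokenizer.json").await {
        return Tokenizer::from_file(path).ok();
    }
    // 2. Otherwise follow `base_model_name_or_path` from the repo's config.json.
    let config_path = api_repo.get("config.json").await.ok()?;
    let file = std::fs::File::open(config_path).ok()?;
    let config: serde_json::Value = serde_json::from_reader(std::io::BufReader::new(file)).ok()?;
    let base_model_id = config.get("base_model_name_or_path")?.as_str()?;
    let base_repo = api.repo(Repo::with_revision(
        base_model_id.to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    let tokenizer_path = base_repo.get("tokenizer.json").await.ok()?;
    Tokenizer::from_file(tokenizer_path).ok()
}

In the commit itself this logic is split between the tokenizer resolution in main() and the new get_base_tokenizer() helper shown in the last hunk.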

Cargo.lock (generated)

@@ -983,13 +983,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
 dependencies = [
  "dirs 5.0.1",
+ "futures",
  "indicatif",
  "log",
  "native-tls",
+ "num_cpus",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "thiserror",
+ "tokio",
  "ureq",
 ]

router/Cargo.toml

@@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
+hf-hub = { version = "0.3.0", features = ["tokio"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
 nohash-hasher = "0.2.0"
@@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
-hf-hub = "0.3.1"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
 
 [build-dependencies]

router/src/main.rs

@@ -1,8 +1,9 @@
 /// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
-/// Text Generation Inference webserver entrypoint
 use axum::http::HeaderValue;
 use clap::Parser;
+use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
+use hf_hub::{Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
@@ -10,13 +11,15 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::env;
+/// Text Generation Inference webserver entrypoint
+use std::fs::File;
+use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
-use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
 use thiserror::Error;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -72,7 +75,8 @@ struct Args {
     ngrok_edge: Option<String>,
 }
 
-fn main() -> Result<(), RouterError> {
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -101,6 +105,9 @@ fn main() -> Result<(), RouterError> {
         ngrok_edge,
     } = args;
 
+    // Launch Tokio runtime
+    init_logging(otlp_endpoint, json_output);
+
     // Validate args
     if max_input_length >= max_total_tokens {
         return Err(RouterError::ArgumentValidation(
@@ -147,57 +154,66 @@ fn main() -> Result<(), RouterError> {
     let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
         .ok()
         .map_or(false, |value| value.to_lowercase() == "true");
-    let tokenizer = if skip_tokenizer_in_tgi {
-        None
-    } else if local_model {
-        // Load local tokenizer
-        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
-    } else {
-        // Download and instantiate tokenizer
-        // We need to download it outside of the Tokio runtime
-        let params = FromPretrainedParameters {
-            revision: revision.clone().unwrap_or("main".to_string()),
-            auth_token: authorization_token.clone(),
-            ..Default::default()
-        };
-        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
-    };
-
-    // Launch Tokio runtime
-    tokio::runtime::Builder::new_multi_thread()
-        .enable_all()
-        .build()?
-        .block_on(async {
-            init_logging(otlp_endpoint, json_output);
-
-            if skip_tokenizer_in_tgi {
-                tracing::warn!("Rust input length validation disabled by environment variable");
-            } else if tokenizer.is_none() {
-                tracing::warn!(
-                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
-                );
-                tracing::warn!("Rust input length validation and truncation is disabled");
-            }
-
-            // Get Model info
-            let model_info = match local_model {
-                true => HubModelInfo {
-                    model_id: tokenizer_name.clone(),
-                    sha: None,
-                    pipeline_tag: None,
-                },
-                false => get_model_info(&tokenizer_name, revision, authorization_token)
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-                        HubModelInfo {
-                            model_id: tokenizer_name.to_string(),
-                            sha: None,
-                            pipeline_tag: None,
-                        }
-                    }),
-            };
+    let (tokenizer, model_info) = if local_model {
+        // Get Model info
+        let model_info = HubModelInfo {
+            model_id: tokenizer_name.clone(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
+        // Load local tokenizer
+        let tokenizer = if skip_tokenizer_in_tgi {
+            None
+        } else {
+            Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
+        };
+
+        (tokenizer, model_info)
+    } else {
+        let mut builder = ApiBuilder::new()
+            .with_progress(false)
+            .with_token(authorization_token);
+
+        if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
+            builder = builder.with_cache_dir(cache_dir.into());
+        }
+
+        if revision.is_none() {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+        }
+
+        let api = builder.build().unwrap();
+        let api_repo = api.repo(Repo::with_revision(
+            tokenizer_name.clone(),
+            RepoType::Model,
+            revision.clone().unwrap_or("main".to_string()),
+        ));
+
+        // Get Model info
+        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
+            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+            HubModelInfo {
+                model_id: tokenizer_name.to_string(),
+                sha: None,
+                pipeline_tag: None,
+            }
+        });
+
+        let tokenizer = match api_repo.get("tokenizer.json").await {
+            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            Err(_) => get_base_tokenizer(&api, &api_repo).await,
+        };
+
+        (tokenizer, model_info)
+    };
+
+    if tokenizer.is_none() {
+        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
+        tracing::warn!("Rust input length validation and truncation is disabled");
+    }
 
     // if pipeline-tag == text-generation we default to return_full_text = true
     let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
@@ -222,15 +238,19 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_total_tokens as u32,
+            max_batch_total_tokens,
+        )
         .await
         .map_err(RouterError::Warmup)?
     {
         // Older models do not support automatic max-batch-total-tokens
         None => {
-            let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
-                16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
-            );
+            let max_batch_total_tokens = max_batch_total_tokens
+                .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
            tracing::warn!("Model does not support automatic max batch total tokens");
            max_batch_total_tokens
        }
@@ -290,7 +310,6 @@ fn main() -> Result<(), RouterError> {
     )
     .await?;
     Ok(())
-    })
 }
 
 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
@@ -349,30 +368,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }
 
 /// get model info from the Huggingface Hub
-pub async fn get_model_info(
-    model_id: &str,
-    revision: Option<String>,
-    token: Option<String>,
-) -> Option<HubModelInfo> {
-    let revision = match revision {
-        None => {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-            "main".to_string()
-        }
-        Some(revision) => revision,
-    };
-
-    let client = reqwest::Client::new();
-    // Poor man's urlencode
-    let revision = revision.replace('/', "%2F");
-    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
-    let mut builder = client.get(url).timeout(Duration::from_secs(5));
-    if let Some(token) = token {
-        builder = builder.bearer_auth(token);
-    }
-
-    let response = builder.send().await.ok()?;
+pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
+    let response = api.info_request().send().await.ok()?;
 
     if response.status().is_success() {
         let hub_model_info: HubModelInfo =
@@ -389,6 +386,31 @@ pub async fn get_model_info(
     }
 }
 
+/// get base tokenizer
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+    let config_filename = api_repo.get("config.json").await.ok()?;
+
+    // Open the file in read-only mode with buffer.
+    let file = File::open(config_filename).ok()?;
+    let reader = BufReader::new(file);
+
+    // Read the JSON contents of the file as a serde_json::Value.
+    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
+
+    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
+        let api_base_repo = api.repo(Repo::with_revision(
+            base_model_id.to_string(),
+            RepoType::Model,
+            "main".to_string(),
+        ));
+
+        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
+        Tokenizer::from_file(tokenizer_filename).ok()
+    } else {
+        None
+    }
+}
+
 #[derive(Debug, Error)]
 enum RouterError {
     #[error("Argument validation error: {0}")]