diff --git a/router/src/lib.rs b/router/src/lib.rs
index 9213657b..f6f8276f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -29,12 +29,34 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     #[serde(default)]
     pub chat_template: Option<String>,
 }
 
+impl HubTokenizerConfig {
+    /// Loads the tokenizer config stored as JSON at `filename`.
+    ///
+    /// This never panics: when the file cannot be read or its contents cannot be
+    /// deserialized, a warning is logged and the default config (no chat
+    /// template) is returned instead.
+    pub fn from_file(filename: &str) -> Self {
+        let content = match std::fs::read_to_string(filename) {
+            Ok(content) => content,
+            Err(err) => {
+                tracing::warn!("Could not read tokenizer config `{}`: {}", filename, err);
+                return Self::default();
+            }
+        };
+
+        serde_json::from_str(&content).unwrap_or_else(|err| {
+            tracing::warn!("Could not parse tokenizer config `{}`: {}", filename, err);
+            Self::default()
+        })
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
diff --git a/router/src/main.rs b/router/src/main.rs
index e9319916..6c6445bb 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -12,7 +12,6 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
-use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
 use thiserror::Error;
@@ -55,6 +54,8 @@ struct Args {
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
     revision: Option<String>,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
@@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
         port,
         master_shard_uds_path,
         tokenizer_name,
+        tokenizer_config_path,
         revision,
         validation_workers,
         json_output,
@@ -149,37 +151,14 @@ async fn main() -> Result<(), RouterError> {
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
 
-    let tokenizer_config: HubTokenizerConfig = match local_model {
-        true => HubTokenizerConfig {
-            chat_template: None,
-        },
-        false => get_tokenizer_config(
-            &tokenizer_name,
-            revision.as_deref(),
-            authorization_token.as_deref(),
-        )
-        .await
-        .unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
-            HubTokenizerConfig {
-                chat_template: None,
-            }
-        }),
-    };
+    // Load tokenizer config
+    // This will be used to format the chat template
+    let local_tokenizer_config_path =
+        tokenizer_config_path.unwrap_or_else(|| "tokenizer_config.json".to_string());
+    let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
 
-    let (tokenizer, model_info) = if local_model {
-        // Get Model info
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.clone(),
-            sha: None,
-            pipeline_tag: None,
-        };
-
-        // Load local tokenizer
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-
-        (tokenizer, model_info)
-    } else {
+    // Shared API builder initialization
+    let api_builder = || {
         let mut builder = ApiBuilder::new()
             .with_progress(false)
             .with_token(authorization_token);
@@ -188,19 +167,42 @@ async fn main() -> Result<(), RouterError> {
             builder = builder.with_cache_dir(cache_dir.into());
         }
 
-        if revision.is_none() {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-        }
+        builder
+    };
 
-        let api = builder.build().unwrap();
+    // The API is needed when a revision is pinned or there is no local model directory
+    let use_api = revision.is_some() || !local_model;
+
+    // Initialize API if needed
+    let api = if use_api {
+        tracing::info!("Using the Hugging Face API");
+        Some(api_builder().build().unwrap())
+    } else {
+        None
+    };
+
+    // Load tokenizer and model info
+    let (tokenizer, model_info) = if local_model {
+        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
+        let model_info = HubModelInfo {
+            model_id: tokenizer_name.to_string(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
+        (tokenizer, model_info)
+    } else if let Some(api) = api.clone() {
         let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.clone(),
+            tokenizer_name.to_string(),
             RepoType::Model,
-            revision.clone().unwrap_or("main".to_string()),
+            revision.clone().unwrap_or_else(|| "main".to_string()),
         ));
 
-        // Get Model info
+        let tokenizer = match api_repo.get("tokenizer.json").await {
+            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            Err(_) => get_base_tokenizer(&api, &api_repo).await,
+        };
+
         let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
             tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
             HubModelInfo {
@@ -210,12 +212,34 @@ async fn main() -> Result<(), RouterError> {
             }
         });
 
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-
         (tokenizer, model_info)
+    } else {
+        // No API and no local model
+        return Err(RouterError::ArgumentValidation(
+            "No local model found and no revision specified".to_string(),
+        ));
+    };
+
+    // Load tokenizer config if found locally, or check if we can get it from the API if needed
+    let tokenizer_config = if local_tokenizer_config {
+        tracing::info!("Using local tokenizer config");
+        HubTokenizerConfig::from_file(&local_tokenizer_config_path)
+    } else if let Some(api) = api {
+        tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
+        get_tokenizer_config(&api.repo(Repo::with_revision(
+            tokenizer_name.to_string(),
+            RepoType::Model,
+            revision.unwrap_or_else(|| "main".to_string()),
+        )))
+        .await
+        .unwrap_or_else(|| {
+            // Keep the previous behaviour: warn and fall back instead of panicking
+            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
+            HubTokenizerConfig::default()
+        })
+    } else {
+        tracing::warn!("Could not find tokenizer config locally and no revision specified");
+        HubTokenizerConfig::default()
     };
 
     if tokenizer.is_none() {
@@ -421,38 +445,17 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
 }
 
 /// get tokenizer_config from the Huggingface Hub
-pub async fn get_tokenizer_config(
-    model_id: &str,
-    revision: Option<&str>,
-    token: Option<&str>,
-) -> Option<HubTokenizerConfig> {
-    let revision = match revision {
-        None => {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-            "main".to_string()
-        }
-        Some(revision) => revision.to_string(),
-    };
-    let client = reqwest::Client::new();
-    // Poor man's urlencode
-    let revision = revision.replace('/', "%2F");
-    let url = format!(
-        "https://huggingface.co/{}/raw/{}/tokenizer_config.json",
-        model_id, revision
-    );
-    let mut builder = client.get(url).timeout(Duration::from_secs(5));
-    if let Some(token) = token {
-        builder = builder.bearer_auth(token);
-    }
-    let response = builder.send().await.ok()?;
-    if response.status().is_success() {
-        let text = response.text().await.ok()?;
-        let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
-        Some(hub_tokenizer_config)
-    } else {
-        None
-    }
+pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
+    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
+
+    // Open the file in read-only mode with buffer.
+    let file = File::open(tokenizer_config_filename).ok()?;
+    let reader = BufReader::new(file);
+
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
+
+    Some(tokenizer_config)
 }
 
 #[derive(Debug, Error)]