diff --git a/router/src/lib.rs b/router/src/lib.rs
index fc5670a0..07360e78 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -37,7 +37,15 @@ pub struct HubTokenizerConfig {
 }
 
 impl HubTokenizerConfig {
-    pub fn from_file(filename: &str) -> Self {
+    /// Loads the tokenizer config stored as JSON at `filename`.
+    ///
+    /// Accepts anything path-like, so both `&str` and `&Path` callers compile.
+    /// Content that fails to deserialize yields `Self::default()`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the file cannot be read.
+    pub fn from_file(filename: impl AsRef<std::path::Path>) -> Self {
         let content = std::fs::read_to_string(filename).unwrap();
         serde_json::from_str(&content).unwrap_or_default()
     }
diff --git a/router/src/main.rs b/router/src/main.rs
index 495fd5bc..9afa5727 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -156,9 +156,13 @@ async fn main() -> Result<(), RouterError> {
 
     // Load tokenizer config
     // This will be used to format the chat template
-    let local_tokenizer_config_path =
-        tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
-    let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
+    let tokenizer_config_full_path = match tokenizer_config_path {
+        // An explicitly provided path always takes precedence
+        Some(path) => Some(std::path::PathBuf::from(path)),
+        // Otherwise default to the tokenizer config next to a local model
+        None if local_model => Some(local_path.join("tokenizer_config.json")),
+        None => None,
+    };
 
     // Shared API builder initialization
     let api_builder = || {
@@ -230,24 +234,35 @@ async fn main() -> Result<(), RouterError> {
     };
 
     // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if local_tokenizer_config {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_tokenizer_config_path)
-    } else if let Some(api) = api {
-        tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-        get_tokenizer_config(&api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.unwrap_or_else(|| "main".to_string()),
-        )))
-        .await
-        .unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
-            HubTokenizerConfig::default()
-        })
-    } else {
-        tracing::warn!("Could not find tokenizer config locally and no revision specified");
-        HubTokenizerConfig::default()
+    // Skip paths that do not exist: `from_file` panics on unreadable files, whereas a
+    // missing local config should fall through to the Hub lookup (or the default) below
+    let tokenizer_config = match tokenizer_config_full_path.filter(|path| path.exists()) {
+        Some(path) => {
+            tracing::info!("Using local tokenizer config");
+            HubTokenizerConfig::from_file(&path)
+        }
+        None => match api {
+            Some(api) => {
+                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
+                let repo = Repo::with_revision(
+                    tokenizer_name.to_string(),
+                    RepoType::Model,
+                    revision.unwrap_or_else(|| "main".to_string()),
+                );
+                get_tokenizer_config(&api.repo(repo))
+                    .await
+                    .unwrap_or_else(|| {
+                        tracing::warn!(
+                            "Could not retrieve tokenizer config from the Hugging Face hub."
+                        );
+                        HubTokenizerConfig::default()
+                    })
+            }
+            None => {
+                tracing::warn!("Could not find tokenizer config locally and no API specified");
+                HubTokenizerConfig::default()
+            }
+        },
     };
 
     if tokenizer.is_none() {