fix: tokenizer config should use local model path when possible

This commit is contained in:
drbh 2024-02-01 13:54:33 +00:00
parent 9ad7b6a1a1
commit 6e08f5b265
2 changed files with 36 additions and 22 deletions

View File

@@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
// Side-by-side diff view: on each line below, the old text is followed by the new text.
// The only change in this hunk is the parameter type: `&str` becomes `&std::path::Path`.
/// Reads the file at `filename` and deserializes its contents as JSON into `Self`.
///
/// Panics (via `unwrap`) if the file cannot be read into a string. If the
/// contents fail to deserialize, the error is discarded and the `Default`
/// value is returned instead (via `unwrap_or_default`).
pub fn from_file(filename: &str) -> Self { pub fn from_file(filename: &std::path::Path) -> Self {
// Read failures are not recoverable here: `unwrap` aborts on any I/O error.
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename).unwrap();
// Deserialization failures silently fall back to the default configuration.
serde_json::from_str(&content).unwrap_or_default() serde_json::from_str(&content).unwrap_or_default()
} }

View File

@@ -156,9 +156,14 @@ async fn main() -> Result<(), RouterError> {
// Load tokenizer config // Load tokenizer config
// This will be used to format the chat template // This will be used to format the chat template
let local_tokenizer_config_path = let tokenizer_config_full_path = if tokenizer_config_path.is_none() && local_model {
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); // if no tokenizer config path is provided, we default to the local tokenizer config
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); Some(local_path.join("tokenizer_config.json"))
} else if let Some(tokenizer_config_path) = tokenizer_config_path {
Some(std::path::PathBuf::from(tokenizer_config_path))
} else {
None
};
// Shared API builder initialization // Shared API builder initialization
let api_builder = || { let api_builder = || {
@@ -230,24 +235,33 @@ async fn main() -> Result<(), RouterError> {
}; };
// Load tokenizer config if found locally, or check if we can get it from the API if needed // Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config { let tokenizer_config = match tokenizer_config_full_path {
tracing::info!("Using local tokenizer config"); Some(path) => {
HubTokenizerConfig::from_file(&local_tokenizer_config_path) tracing::info!("Using local tokenizer config");
} else if let Some(api) = api { HubTokenizerConfig::from_file(&path)
tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); }
get_tokenizer_config(&api.repo(Repo::with_revision( None => match api {
tokenizer_name.to_string(), Some(api) => {
RepoType::Model, tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
revision.unwrap_or_else(|| "main".to_string()), let repo = Repo::with_revision(
))) tokenizer_name.to_string(),
.await RepoType::Model,
.unwrap_or_else(|| { revision.unwrap_or("main".to_string()),
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); );
HubTokenizerConfig::default() get_tokenizer_config(&api.repo(repo))
}) .await
} else { .unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no revision specified"); tracing::warn!(
HubTokenizerConfig::default() "Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
},
}; };
if tokenizer.is_none() { if tokenizer.is_none() {