fix: tokenizer config should use local model path when possible

This commit is contained in:
drbh 2024-02-01 13:54:33 +00:00
parent 9ad7b6a1a1
commit 6e08f5b265
2 changed files with 36 additions and 22 deletions

View File

@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
pub fn from_file(filename: &str) -> Self { pub fn from_file(filename: &std::path::Path) -> Self {
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename).unwrap();
serde_json::from_str(&content).unwrap_or_default() serde_json::from_str(&content).unwrap_or_default()
} }

View File

@ -156,9 +156,14 @@ async fn main() -> Result<(), RouterError> {
// Load tokenizer config // Load tokenizer config
// This will be used to format the chat template // This will be used to format the chat template
let local_tokenizer_config_path = let tokenizer_config_full_path = if tokenizer_config_path.is_none() && local_model {
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); // if no tokenizer config path is provided, we default to the local tokenizer config
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); Some(local_path.join("tokenizer_config.json"))
} else if let Some(tokenizer_config_path) = tokenizer_config_path {
Some(std::path::PathBuf::from(tokenizer_config_path))
} else {
None
};
// Shared API builder initialization // Shared API builder initialization
let api_builder = || { let api_builder = || {
@ -230,24 +235,33 @@ async fn main() -> Result<(), RouterError> {
}; };
// Load tokenizer config if found locally, or check if we can get it from the API if needed // Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config { let tokenizer_config = match tokenizer_config_full_path {
Some(path) => {
tracing::info!("Using local tokenizer config"); tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path) HubTokenizerConfig::from_file(&path)
} else if let Some(api) = api { }
None => match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision( let repo = Repo::with_revision(
tokenizer_name.to_string(), tokenizer_name.to_string(),
RepoType::Model, RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()), revision.unwrap_or("main".to_string()),
))) );
get_tokenizer_config(&api.repo(repo))
.await .await
.unwrap_or_else(|| { .unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default() HubTokenizerConfig::default()
}) })
} else { }
tracing::warn!("Could not find tokenizer config locally and no revision specified"); None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default() HubTokenizerConfig::default()
}
},
}; };
if tokenizer.is_none() { if tokenizer.is_none() {