fix: tokenizer config should use local model path when possible

This commit is contained in:
drbh 2024-02-01 13:54:33 +00:00
parent 9ad7b6a1a1
commit 6e08f5b265
2 changed files with 36 additions and 22 deletions

View File

@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
}
impl HubTokenizerConfig {
pub fn from_file(filename: &str) -> Self {
/// Loads a tokenizer configuration from a JSON file on disk.
///
/// Returns `Self::default()` when the file cannot be read or when its contents
/// fail to deserialize, so this function never panics on a bad path.
pub fn from_file(filename: &std::path::Path) -> Self {
    // Callers may pass a path that was never checked for existence (e.g. a
    // `tokenizer_config.json` joined onto a local model directory), so an
    // unreadable file is treated the same way as an unparsable one.
    std::fs::read_to_string(filename)
        .ok()
        .and_then(|content| serde_json::from_str(&content).ok())
        .unwrap_or_default()
}

View File

@ -156,9 +156,14 @@ async fn main() -> Result<(), RouterError> {
// Load tokenizer config
// This will be used to format the chat template
let local_tokenizer_config_path =
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
let tokenizer_config_full_path = if tokenizer_config_path.is_none() && local_model {
// if no tokenizer config path is provided, we default to the local tokenizer config
Some(local_path.join("tokenizer_config.json"))
} else if let Some(tokenizer_config_path) = tokenizer_config_path {
Some(std::path::PathBuf::from(tokenizer_config_path))
} else {
None
};
// Shared API builder initialization
let api_builder = || {
@ -230,24 +235,33 @@ async fn main() -> Result<(), RouterError> {
};
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
} else if let Some(api) = api {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()),
)))
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
HubTokenizerConfig::default()
})
} else {
tracing::warn!("Could not find tokenizer config locally and no revision specified");
HubTokenizerConfig::default()
let tokenizer_config = match tokenizer_config_full_path {
Some(path) => {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&path)
}
None => match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
},
};
if tokenizer.is_none() {