mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: tokenizer config should use local model path when possible
This commit is contained in:
parent 9ad7b6a1a1
commit 6e08f5b265
@@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
 }
 
 impl HubTokenizerConfig {
-    pub fn from_file(filename: &str) -> Self {
+    pub fn from_file(filename: &std::path::Path) -> Self {
         let content = std::fs::read_to_string(filename).unwrap();
         serde_json::from_str(&content).unwrap_or_default()
     }
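Note on the signature change: `std::fs::read_to_string` accepts any `AsRef<Path>`, so the body is untouched; taking `&std::path::Path` just lets callers pass a joined `PathBuf` without a string round-trip. A minimal sketch of the call-site difference, using hypothetical paths that are not from the repo:

use std::path::{Path, PathBuf};

// Stand-in for the updated `from_file`: same body, path-typed parameter.
fn read_config(filename: &Path) -> String {
    // Panics if the file is missing, just like the original `unwrap()`.
    std::fs::read_to_string(filename).unwrap()
}

fn main() {
    let local_path = PathBuf::from("/data/my-model");
    // `join` returns a `PathBuf`, which borrows directly as `&Path`;
    // with a `&str` parameter this would need a lossy `to_str()` conversion.
    let content = read_config(&local_path.join("tokenizer_config.json"));
    println!("{content}");
}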
@@ -156,9 +156,14 @@ async fn main() -> Result<(), RouterError> {
 
     // Load tokenizer config
     // This will be used to format the chat template
-    let local_tokenizer_config_path =
-        tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
-    let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
+    let tokenizer_config_full_path = if tokenizer_config_path.is_none() && local_model {
+        // if no tokenizer config path is provided, we default to the local tokenizer config
+        Some(local_path.join("tokenizer_config.json"))
+    } else if let Some(tokenizer_config_path) = tokenizer_config_path {
+        Some(std::path::PathBuf::from(tokenizer_config_path))
+    } else {
+        None
+    };
 
     // Shared API builder initialization
     let api_builder = || {
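The new resolution order in this hunk: an explicit tokenizer config path always wins; otherwise a local model falls back to the `tokenizer_config.json` sitting next to the model; otherwise there is nothing to load locally. A minimal sketch of that logic in isolation, where the parameter names mirror the router's values but the helper itself is illustrative only:

use std::path::{Path, PathBuf};

fn resolve_tokenizer_config(
    tokenizer_config_path: Option<String>,
    local_model: bool,
    local_path: &Path,
) -> Option<PathBuf> {
    if tokenizer_config_path.is_none() && local_model {
        // No explicit path and the model lives on disk: use the config next to it.
        Some(local_path.join("tokenizer_config.json"))
    } else {
        // Otherwise use the explicit path if one was given.
        tokenizer_config_path.map(PathBuf::from)
    }
}

fn main() {
    let local = Path::new("/data/my-model");
    // Local model, no explicit path -> /data/my-model/tokenizer_config.json
    assert_eq!(
        resolve_tokenizer_config(None, true, local),
        Some(local.join("tokenizer_config.json"))
    );
    // An explicit path wins whether or not the model is local.
    assert_eq!(
        resolve_tokenizer_config(Some("custom.json".into()), true, local),
        Some(PathBuf::from("custom.json"))
    );
    // Remote model, no explicit path -> nothing to load locally.
    assert_eq!(resolve_tokenizer_config(None, false, local), None);
}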
@@ -230,24 +235,33 @@ async fn main() -> Result<(), RouterError> {
     };
 
     // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if local_tokenizer_config {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_tokenizer_config_path)
-    } else if let Some(api) = api {
-        tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-        get_tokenizer_config(&api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.unwrap_or_else(|| "main".to_string()),
-        )))
-        .await
-        .unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
-            HubTokenizerConfig::default()
-        })
-    } else {
-        tracing::warn!("Could not find tokenizer config locally and no revision specified");
-        HubTokenizerConfig::default()
-    };
+    let tokenizer_config = match tokenizer_config_full_path {
+        Some(path) => {
+            tracing::info!("Using local tokenizer config");
+            HubTokenizerConfig::from_file(&path)
+        }
+        None => match api {
+            Some(api) => {
+                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
+                let repo = Repo::with_revision(
+                    tokenizer_name.to_string(),
+                    RepoType::Model,
+                    revision.unwrap_or("main".to_string()),
+                );
+                get_tokenizer_config(&api.repo(repo))
+                    .await
+                    .unwrap_or_else(|| {
+                        tracing::warn!(
+                            "Could not retrieve tokenizer config from the Hugging Face hub."
+                        );
+                        HubTokenizerConfig::default()
+                    })
+            }
+            None => {
+                tracing::warn!("Could not find tokenizer config locally and no API specified");
+                HubTokenizerConfig::default()
+            }
+        },
+    };
 
     if tokenizer.is_none() {
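The last hunk keeps the same three-level fallback as before, just keyed off the resolved path instead of a separate `exists()` check: a resolved local file is parsed, otherwise the Hub is queried when an API client is available, otherwise defaults are used. A compressed sketch of that chain, assuming serde/serde_json with the derive feature; `Config` and the closure stand in for `HubTokenizerConfig` and the async `get_tokenizer_config` call and are not the real API:

use std::path::PathBuf;

#[derive(serde::Deserialize, Default, Debug)]
struct Config {
    chat_template: Option<String>,
}

fn load_config(local_path: Option<PathBuf>, hub: Option<&dyn Fn() -> Option<Config>>) -> Config {
    match local_path {
        // 1. A resolved local path (explicit flag, or next to a local model) wins.
        Some(path) => {
            let content = std::fs::read_to_string(path).unwrap();
            serde_json::from_str(&content).unwrap_or_default()
        }
        None => match hub {
            // 2. No local file but an API client: try the Hub, fall back to defaults.
            Some(fetch) => fetch().unwrap_or_default(),
            // 3. Neither: warn (omitted here) and use defaults.
            None => Config::default(),
        },
    }
}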