Using @drbh patch.

This commit is contained in:
Nicolas Patry 2024-04-23 14:27:43 +00:00
parent af24703708
commit 429092683b

View File

@ -187,12 +187,12 @@ async fn main() -> Result<(), RouterError> {
None,
}
let api = if use_api {
tracing::info!("Using the Hugging Face API");
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
@ -207,18 +207,12 @@ async fn main() -> Result<(), RouterError> {
// Load tokenizer and model info
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
Type::None => {
let tokenizer_filename = Some(local_path.join("tokenizer.json"));
let config_filename = Some(local_path.join("config.json"));
let tokenizer_config_filename = Some(local_path.join("tokenizer_config.json"));
let model_info = None;
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
model_info,
)
}
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
@ -247,21 +241,16 @@ async fn main() -> Result<(), RouterError> {
)
}
Type::Cache(cache) => {
let cache_repo = cache.repo(Repo::with_revision(
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = cache_repo.get("tokenizer.json");
let config_filename = cache_repo.get("config.json");
let tokenizer_config_filename = cache_repo.get("tokenizer_config.json");
let model_info = None;
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
model_info,
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
None,
)
}
};