Adding support for HF_HUB_OFFLINE in the router.

This commit is contained in:
Nicolas Patry 2024-04-22 13:53:08 +00:00
parent bfddfa5955
commit af24703708
2 changed files with 111 additions and 101 deletions

View File

@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
pub fn from_file(filename: &std::path::Path) -> Self { pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).unwrap_or_default() serde_json::from_str(&content).ok()
} }
} }

View File

@ -1,7 +1,7 @@
use axum::http::HeaderValue; use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::trace::Sampler;
@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
use std::fs::File; use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient}; use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config; use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
// Tokenizer instance // Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name); let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
// Shared API builder initialization // Shared API builder initialization
let api_builder = || { let api_builder = || {
@ -181,46 +180,94 @@ async fn main() -> Result<(), RouterError> {
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed // Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api { let api = if use_api {
tracing::info!("Using the Hugging Face API"); tracing::info!("Using the Hugging Face API");
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
match api_builder().build() { match api_builder().build() {
Ok(api) => Some(api), Ok(api) => Type::Api(api),
Err(_) => { Err(_) => {
tracing::warn!("Unable to build the Hugging Face API"); tracing::warn!("Unable to build the Hugging Face API");
None Type::None
}
} }
} }
} else { } else {
None Type::None
}; };
// Load tokenizer and model info // Load tokenizer and model info
let (tokenizer, model_info, config) = if local_model { let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); Type::None => {
let model_info = HubModelInfo { let tokenizer_filename = Some(local_path.join("tokenizer.json"));
model_id: tokenizer_name.to_string(), let config_filename = Some(local_path.join("config.json"));
sha: None, let tokenizer_config_filename = Some(local_path.join("tokenizer_config.json"));
pipeline_tag: None, let model_info = None;
}; (
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json")) tokenizer_filename,
.ok() config_filename,
.as_ref() tokenizer_config_filename,
.and_then(|c| serde_json::from_str(c).ok()); model_info,
)
(tokenizer, model_info, config) }
} else if let Some(api) = api.clone() { Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision( let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(), tokenizer_name.to_string(),
RepoType::Model, RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()), revision.clone().unwrap_or_else(|| "main".to_string()),
)); ));
let tokenizer = match api_repo.get("tokenizer.json").await { let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(), Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await, Err(_) => get_base_tokenizer(&api, &api_repo).await,
}; };
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| { let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let cache_repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = cache_repo.get("tokenizer.json");
let config_filename = cache_repo.get("config.json");
let tokenizer_config_filename = cache_repo.get("tokenizer_config.json");
let model_info = None;
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
model_info,
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename) std::fs::read_to_string(filename)
.ok() .ok()
.as_ref() .as_ref()
@ -232,61 +279,25 @@ async fn main() -> Result<(), RouterError> {
config.ok() config.ok()
}) })
}); });
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
model_id: tokenizer_name.to_string(), model_id: tokenizer_name.to_string(),
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}
}); });
(tokenizer, model_info, config) // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else { } else {
// No API and no local model tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
return Err(RouterError::ArgumentValidation(
"No local model found and no revision specified".to_string(),
));
}; };
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::info!("Using config {config:?}");
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!(
"Using local tokenizer config from user specified path {}",
path
);
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else {
match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified"); tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default() HubTokenizerConfig::default()
} });
}
};
tracing::info!("Using config {config:?}");
if tokenizer.is_none() { if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled"); tracing::warn!("Rust input length validation and truncation is disabled");
@ -483,7 +494,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
} }
/// get base tokenizer /// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> { pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?; let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer. // Open the file in read-only mode with buffer.
@ -500,8 +511,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
"main".to_string(), "main".to_string(),
)); ));
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?; api_base_repo.get("tokenizer.json").await.ok()
Tokenizer::from_file(tokenizer_filename).ok()
} else { } else {
None None
} }