mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Adding support for HF_HUB_OFFLINE
in the router.
This commit is contained in:
parent
bfddfa5955
commit
af24703708
@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
|
||||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
pub fn from_file(filename: &std::path::Path) -> Self {
|
||||
let content = std::fs::read_to_string(filename).unwrap();
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use axum::http::HeaderValue;
|
||||
use clap::Parser;
|
||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
use opentelemetry::sdk::trace::Sampler;
|
||||
@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
|
||||
// Tokenizer instance
|
||||
// This will only be used to validate payloads
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
let local_model = local_path.exists() && local_path.is_dir();
|
||||
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
@ -181,112 +180,124 @@ async fn main() -> Result<(), RouterError> {
|
||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||
|
||||
// Initialize API if needed
|
||||
#[derive(Clone)]
|
||||
enum Type {
|
||||
Api(Api),
|
||||
Cache(Cache),
|
||||
None,
|
||||
}
|
||||
let api = if use_api {
|
||||
tracing::info!("Using the Hugging Face API");
|
||||
match api_builder().build() {
|
||||
Ok(api) => Some(api),
|
||||
Err(_) => {
|
||||
tracing::warn!("Unable to build the Hugging Face API");
|
||||
None
|
||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||
let cache = Cache::default();
|
||||
tracing::warn!("Offline mode active using cache defaults");
|
||||
Type::Cache(cache)
|
||||
} else {
|
||||
match api_builder().build() {
|
||||
Ok(api) => Type::Api(api),
|
||||
Err(_) => {
|
||||
tracing::warn!("Unable to build the Hugging Face API");
|
||||
Type::None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
Type::None
|
||||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer, model_info, config) = if local_model {
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| serde_json::from_str(c).ok());
|
||||
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
|
||||
Type::None => {
|
||||
let tokenizer_filename = Some(local_path.join("tokenizer.json"));
|
||||
let config_filename = Some(local_path.join("config.json"));
|
||||
let tokenizer_config_filename = Some(local_path.join("tokenizer_config.json"));
|
||||
let model_info = None;
|
||||
(
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
Type::Api(api) => {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
|
||||
(tokenizer, model_info, config)
|
||||
} else if let Some(api) = api.clone() {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
let config_filename = api_repo.get("config.json").await.ok();
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
|
||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||
Some(model_info)
|
||||
} else {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
None
|
||||
};
|
||||
(
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
Type::Cache(cache) => {
|
||||
let cache_repo = cache.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
|
||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
|
||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
}
|
||||
});
|
||||
|
||||
(tokenizer, model_info, config)
|
||||
} else {
|
||||
// No API and no local model
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"No local model found and no revision specified".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||
tracing::info!(
|
||||
"Using local tokenizer config from user specified path {}",
|
||||
path
|
||||
);
|
||||
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
|
||||
} else if local_model {
|
||||
tracing::info!("Using local tokenizer config");
|
||||
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
|
||||
} else {
|
||||
match api {
|
||||
Some(api) => {
|
||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||
let repo = Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.unwrap_or("main".to_string()),
|
||||
);
|
||||
get_tokenizer_config(&api.repo(repo))
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
"Could not retrieve tokenizer config from the Hugging Face hub."
|
||||
);
|
||||
HubTokenizerConfig::default()
|
||||
})
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
}
|
||||
let tokenizer_filename = cache_repo.get("tokenizer.json");
|
||||
let config_filename = cache_repo.get("config.json");
|
||||
let tokenizer_config_filename = cache_repo.get("tokenizer_config.json");
|
||||
let model_info = None;
|
||||
(
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
};
|
||||
let tokenizer: Option<Tokenizer> =
|
||||
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
});
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
{
|
||||
HubTokenizerConfig::from_file(filename)
|
||||
} else {
|
||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||
};
|
||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
if tokenizer.is_none() {
|
||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||
@ -483,7 +494,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||
}
|
||||
|
||||
/// get base tokenizer
|
||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
|
||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
@ -500,8 +511,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
|
||||
"main".to_string(),
|
||||
));
|
||||
|
||||
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
|
||||
Tokenizer::from_file(tokenizer_filename).ok()
|
||||
api_base_repo.get("tokenizer.json").await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user