mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Adding support for HF_HUB_OFFLINE
in the router.
This commit is contained in:
parent
bfddfa5955
commit
af24703708
@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
pub fn from_file(filename: &std::path::Path) -> Self {
|
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||||
let content = std::fs::read_to_string(filename).unwrap();
|
let content = std::fs::read_to_string(filename).ok()?;
|
||||||
serde_json::from_str(&content).unwrap_or_default()
|
serde_json::from_str(&content).ok()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||||
use hf_hub::{Repo, RepoType};
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
|
|||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
use text_generation_router::config::Config;
|
use text_generation_router::config::Config;
|
||||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||||
@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
// Tokenizer instance
|
// Tokenizer instance
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let local_model = local_path.exists() && local_path.is_dir();
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
// Shared API builder initialization
|
||||||
let api_builder = || {
|
let api_builder = || {
|
||||||
@ -181,46 +180,94 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
// Initialize API if needed
|
// Initialize API if needed
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum Type {
|
||||||
|
Api(Api),
|
||||||
|
Cache(Cache),
|
||||||
|
None,
|
||||||
|
}
|
||||||
let api = if use_api {
|
let api = if use_api {
|
||||||
tracing::info!("Using the Hugging Face API");
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||||
|
let cache = Cache::default();
|
||||||
|
tracing::warn!("Offline mode active using cache defaults");
|
||||||
|
Type::Cache(cache)
|
||||||
|
} else {
|
||||||
match api_builder().build() {
|
match api_builder().build() {
|
||||||
Ok(api) => Some(api),
|
Ok(api) => Type::Api(api),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
tracing::warn!("Unable to build the Hugging Face API");
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
None
|
Type::None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
Type::None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Load tokenizer and model info
|
// Load tokenizer and model info
|
||||||
let (tokenizer, model_info, config) = if local_model {
|
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
|
||||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
Type::None => {
|
||||||
let model_info = HubModelInfo {
|
let tokenizer_filename = Some(local_path.join("tokenizer.json"));
|
||||||
model_id: tokenizer_name.to_string(),
|
let config_filename = Some(local_path.join("config.json"));
|
||||||
sha: None,
|
let tokenizer_config_filename = Some(local_path.join("tokenizer_config.json"));
|
||||||
pipeline_tag: None,
|
let model_info = None;
|
||||||
};
|
(
|
||||||
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
|
tokenizer_filename,
|
||||||
.ok()
|
config_filename,
|
||||||
.as_ref()
|
tokenizer_config_filename,
|
||||||
.and_then(|c| serde_json::from_str(c).ok());
|
model_info,
|
||||||
|
)
|
||||||
(tokenizer, model_info, config)
|
}
|
||||||
} else if let Some(api) = api.clone() {
|
Type::Api(api) => {
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
tokenizer_name.to_string(),
|
tokenizer_name.to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
));
|
));
|
||||||
|
|
||||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
};
|
};
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
|
||||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||||
|
Some(model_info)
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
model_info,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type::Cache(cache) => {
|
||||||
|
let cache_repo = cache.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = cache_repo.get("tokenizer.json");
|
||||||
|
let config_filename = cache_repo.get("config.json");
|
||||||
|
let tokenizer_config_filename = cache_repo.get("tokenizer_config.json");
|
||||||
|
let model_info = None;
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
model_info,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tokenizer: Option<Tokenizer> =
|
||||||
|
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||||
|
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||||
std::fs::read_to_string(filename)
|
std::fs::read_to_string(filename)
|
||||||
.ok()
|
.ok()
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@ -232,61 +279,25 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
config.ok()
|
config.ok()
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
model_id: tokenizer_name.to_string(),
|
||||||
sha: None,
|
sha: None,
|
||||||
pipeline_tag: None,
|
pipeline_tag: None,
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
(tokenizer, model_info, config)
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
|
{
|
||||||
|
HubTokenizerConfig::from_file(filename)
|
||||||
} else {
|
} else {
|
||||||
// No API and no local model
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"No local model found and no revision specified".to_string(),
|
|
||||||
));
|
|
||||||
};
|
};
|
||||||
|
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||||
tracing::info!("Using config {config:?}");
|
|
||||||
|
|
||||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
|
||||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
|
||||||
tracing::info!(
|
|
||||||
"Using local tokenizer config from user specified path {}",
|
|
||||||
path
|
|
||||||
);
|
|
||||||
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
|
|
||||||
} else if local_model {
|
|
||||||
tracing::info!("Using local tokenizer config");
|
|
||||||
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
|
|
||||||
} else {
|
|
||||||
match api {
|
|
||||||
Some(api) => {
|
|
||||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
|
||||||
let repo = Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.unwrap_or("main".to_string()),
|
|
||||||
);
|
|
||||||
get_tokenizer_config(&api.repo(repo))
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
tracing::warn!(
|
|
||||||
"Could not retrieve tokenizer config from the Hugging Face hub."
|
|
||||||
);
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||||
HubTokenizerConfig::default()
|
HubTokenizerConfig::default()
|
||||||
}
|
});
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
tracing::info!("Using config {config:?}");
|
||||||
if tokenizer.is_none() {
|
if tokenizer.is_none() {
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||||
@ -483,7 +494,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// get base tokenizer
|
/// get base tokenizer
|
||||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
|
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
// Open the file in read-only mode with buffer.
|
||||||
@ -500,8 +511,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
|
|||||||
"main".to_string(),
|
"main".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
|
api_base_repo.get("tokenizer.json").await.ok()
|
||||||
Tokenizer::from_file(tokenizer_filename).ok()
|
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user