mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: support local configs and prefer hf hub
This commit is contained in:
parent
3513bc73b2
commit
fb6c220dc8
@ -29,12 +29,19 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
/// Subset of a model's tokenizer configuration that is deserialized from JSON.
///
/// The derived `Default` produces a config without a chat template.
#[derive(Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
/// Optional chat template; `#[serde(default)]` makes a missing key
/// deserialize to `None` instead of failing.
#[serde(default)]
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
pub fn from_file(filename: &str) -> Self {
|
||||
let content = std::fs::read_to_string(filename).unwrap();
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
|
@ -12,7 +12,6 @@ use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
@ -55,6 +54,8 @@ struct Args {
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
json_output,
|
||||
@ -149,37 +151,14 @@ async fn main() -> Result<(), RouterError> {
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
let local_model = local_path.exists() && local_path.is_dir();
|
||||
|
||||
let tokenizer_config: HubTokenizerConfig = match local_model {
|
||||
true => HubTokenizerConfig {
|
||||
chat_template: None,
|
||||
},
|
||||
false => get_tokenizer_config(
|
||||
&tokenizer_name,
|
||||
revision.as_deref(),
|
||||
authorization_token.as_deref(),
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
||||
HubTokenizerConfig {
|
||||
chat_template: None,
|
||||
}
|
||||
}),
|
||||
};
|
||||
// Load tokenizer config
|
||||
// This will be used to format the chat template
|
||||
let local_tokenizer_config_path =
|
||||
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
|
||||
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
|
||||
|
||||
let (tokenizer, model_info) = if local_model {
|
||||
// Get Model info
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.clone(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
|
||||
// Load local tokenizer
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
|
||||
(tokenizer, model_info)
|
||||
} else {
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
let mut builder = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_token(authorization_token);
|
||||
@ -188,19 +167,42 @@ async fn main() -> Result<(), RouterError> {
|
||||
builder = builder.with_cache_dir(cache_dir.into());
|
||||
}
|
||||
|
||||
if revision.is_none() {
|
||||
tracing::warn!("`--revision` is not set");
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
}
|
||||
builder
|
||||
};
|
||||
|
||||
let api = builder.build().unwrap();
|
||||
// Decide if we need to use the API based on the revision and local path
|
||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||
|
||||
// Initialize API if needed
|
||||
let api = if use_api {
|
||||
tracing::info!("Using the Hugging Face API");
|
||||
Some(api_builder().build().unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer, model_info) = if local_model {
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
|
||||
(tokenizer, model_info)
|
||||
} else if let Some(api) = api.clone() {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.clone(),
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or("main".to_string()),
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
|
||||
// Get Model info
|
||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
|
||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo {
|
||||
@ -210,12 +212,30 @@ async fn main() -> Result<(), RouterError> {
|
||||
}
|
||||
});
|
||||
|
||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
(tokenizer, model_info)
|
||||
} else {
|
||||
// No API and no local model
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"No local model found and no revision specified".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
(tokenizer, model_info)
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if local_tokenizer_config {
|
||||
tracing::info!("Using local tokenizer config");
|
||||
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
|
||||
} else if let Some(api) = api {
|
||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||
get_tokenizer_config(&api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.unwrap_or_else(|| "main".to_string()),
|
||||
)))
|
||||
.await
|
||||
.unwrap()
|
||||
} else {
|
||||
tracing::warn!("Could not find tokenizer config locally and no revision specified");
|
||||
HubTokenizerConfig::default()
|
||||
};
|
||||
|
||||
if tokenizer.is_none() {
|
||||
@ -421,38 +441,17 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
|
||||
}
|
||||
|
||||
/// get tokenizer_config from the Huggingface Hub
|
||||
pub async fn get_tokenizer_config(
|
||||
model_id: &str,
|
||||
revision: Option<&str>,
|
||||
token: Option<&str>,
|
||||
) -> Option<HubTokenizerConfig> {
|
||||
let revision = match revision {
|
||||
None => {
|
||||
tracing::warn!("`--revision` is not set");
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
"main".to_string()
|
||||
}
|
||||
Some(revision) => revision.to_string(),
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
// Poor man's urlencode
|
||||
let revision = revision.replace('/', "%2F");
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/raw/{}/tokenizer_config.json",
|
||||
model_id, revision
|
||||
);
|
||||
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
||||
if let Some(token) = token {
|
||||
builder = builder.bearer_auth(token);
|
||||
}
|
||||
let response = builder.send().await.ok()?;
|
||||
if response.status().is_success() {
|
||||
let text = response.text().await.ok()?;
|
||||
let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
|
||||
Some(hub_tokenizer_config)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
let file = File::open(tokenizer_config_filename).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
|
||||
|
||||
Some(tokenizer_config)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
Loading…
Reference in New Issue
Block a user