mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: support local configs and prefer hf hub
This commit is contained in:
parent
3513bc73b2
commit
fb6c220dc8
@ -29,12 +29,19 @@ pub struct HubModelInfo {
|
|||||||
pub pipeline_tag: Option<String>,
|
pub pipeline_tag: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize)]
|
#[derive(Clone, Deserialize, Default)]
|
||||||
pub struct HubTokenizerConfig {
|
pub struct HubTokenizerConfig {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl HubTokenizerConfig {
|
||||||
|
pub fn from_file(filename: &str) -> Self {
|
||||||
|
let content = std::fs::read_to_string(filename).unwrap();
|
||||||
|
serde_json::from_str(&content).unwrap_or_default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
/// Model info
|
/// Model info
|
||||||
|
@ -12,7 +12,6 @@ use std::fs::File;
|
|||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::time::Duration;
|
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -55,6 +54,8 @@ struct Args {
|
|||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
port,
|
port,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
json_output,
|
json_output,
|
||||||
@ -149,37 +151,14 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let local_model = local_path.exists() && local_path.is_dir();
|
let local_model = local_path.exists() && local_path.is_dir();
|
||||||
|
|
||||||
let tokenizer_config: HubTokenizerConfig = match local_model {
|
// Load tokenizer config
|
||||||
true => HubTokenizerConfig {
|
// This will be used to format the chat template
|
||||||
chat_template: None,
|
let local_tokenizer_config_path =
|
||||||
},
|
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
|
||||||
false => get_tokenizer_config(
|
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
|
||||||
&tokenizer_name,
|
|
||||||
revision.as_deref(),
|
|
||||||
authorization_token.as_deref(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
|
||||||
HubTokenizerConfig {
|
|
||||||
chat_template: None,
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (tokenizer, model_info) = if local_model {
|
// Shared API builder initialization
|
||||||
// Get Model info
|
let api_builder = || {
|
||||||
let model_info = HubModelInfo {
|
|
||||||
model_id: tokenizer_name.clone(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load local tokenizer
|
|
||||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
|
||||||
|
|
||||||
(tokenizer, model_info)
|
|
||||||
} else {
|
|
||||||
let mut builder = ApiBuilder::new()
|
let mut builder = ApiBuilder::new()
|
||||||
.with_progress(false)
|
.with_progress(false)
|
||||||
.with_token(authorization_token);
|
.with_token(authorization_token);
|
||||||
@ -188,19 +167,42 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
builder = builder.with_cache_dir(cache_dir.into());
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
if revision.is_none() {
|
builder
|
||||||
tracing::warn!("`--revision` is not set");
|
};
|
||||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
|
||||||
}
|
|
||||||
|
|
||||||
let api = builder.build().unwrap();
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
let api = if use_api {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
Some(api_builder().build().unwrap())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (tokenizer, model_info) = if local_model {
|
||||||
|
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||||
|
let model_info = HubModelInfo {
|
||||||
|
model_id: tokenizer_name.to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
(tokenizer, model_info)
|
||||||
|
} else if let Some(api) = api.clone() {
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
tokenizer_name.clone(),
|
tokenizer_name.to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
revision.clone().unwrap_or("main".to_string()),
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Get Model info
|
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
|
||||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
HubModelInfo {
|
HubModelInfo {
|
||||||
@ -210,12 +212,30 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
|
||||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
|
||||||
};
|
|
||||||
|
|
||||||
(tokenizer, model_info)
|
(tokenizer, model_info)
|
||||||
|
} else {
|
||||||
|
// No API and no local model
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"No local model found and no revision specified".to_string(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||||
|
let tokenizer_config = if local_tokenizer_config {
|
||||||
|
tracing::info!("Using local tokenizer config");
|
||||||
|
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
|
||||||
|
} else if let Some(api) = api {
|
||||||
|
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||||
|
get_tokenizer_config(&api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.unwrap_or_else(|| "main".to_string()),
|
||||||
|
)))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Could not find tokenizer config locally and no revision specified");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
if tokenizer.is_none() {
|
if tokenizer.is_none() {
|
||||||
@ -421,38 +441,17 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// get tokenizer_config from the Huggingface Hub
|
/// get tokenizer_config from the Huggingface Hub
|
||||||
pub async fn get_tokenizer_config(
|
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||||
model_id: &str,
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||||
revision: Option<&str>,
|
|
||||||
token: Option<&str>,
|
// Open the file in read-only mode with buffer.
|
||||||
) -> Option<HubTokenizerConfig> {
|
let file = File::open(tokenizer_config_filename).ok()?;
|
||||||
let revision = match revision {
|
let reader = BufReader::new(file);
|
||||||
None => {
|
|
||||||
tracing::warn!("`--revision` is not set");
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
|
||||||
"main".to_string()
|
|
||||||
}
|
Some(tokenizer_config)
|
||||||
Some(revision) => revision.to_string(),
|
|
||||||
};
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
// Poor man's urlencode
|
|
||||||
let revision = revision.replace('/', "%2F");
|
|
||||||
let url = format!(
|
|
||||||
"https://huggingface.co/{}/raw/{}/tokenizer_config.json",
|
|
||||||
model_id, revision
|
|
||||||
);
|
|
||||||
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
|
||||||
if let Some(token) = token {
|
|
||||||
builder = builder.bearer_auth(token);
|
|
||||||
}
|
|
||||||
let response = builder.send().await.ok()?;
|
|
||||||
if response.status().is_success() {
|
|
||||||
let text = response.text().await.ok()?;
|
|
||||||
let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
|
|
||||||
Some(hub_tokenizer_config)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
Loading…
Reference in New Issue
Block a user