feat: support local configs and prefer hf hub

This commit is contained in:
drbh 2024-01-15 08:58:11 -05:00
parent 3513bc73b2
commit fb6c220dc8
2 changed files with 83 additions and 77 deletions

View File

@ -29,12 +29,19 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Deserialize, Default)]
/// Fields read from a model's `tokenizer_config.json`.
pub struct HubTokenizerConfig {
    /// Chat template string, if the config provides one.
    /// `#[serde(default)]` lets deserialization succeed with `None` when the
    /// key is absent from the JSON.
    #[serde(default)]
    pub chat_template: Option<String>,
}
impl HubTokenizerConfig {
    /// Loads a tokenizer config from a local JSON file.
    ///
    /// This is a best-effort loader: if `filename` cannot be read (missing,
    /// unreadable, a directory, not valid UTF-8) or its contents do not
    /// deserialize into this type, the default config is returned instead of
    /// panicking.
    pub fn from_file(filename: &str) -> Self {
        // A read failure is treated exactly like a parse failure: both fall
        // through to `Default`, so a bad local file never aborts startup.
        std::fs::read_to_string(filename)
            .ok()
            .and_then(|content| serde_json::from_str(&content).ok())
            .unwrap_or_default()
    }
}
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
/// Model info

View File

@ -12,7 +12,6 @@ use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
use thiserror::Error;
@ -55,6 +54,8 @@ struct Args {
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
json_output,
@ -149,37 +151,14 @@ async fn main() -> Result<(), RouterError> {
let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
let tokenizer_config: HubTokenizerConfig = match local_model {
true => HubTokenizerConfig {
chat_template: None,
},
false => get_tokenizer_config(
&tokenizer_name,
revision.as_deref(),
authorization_token.as_deref(),
)
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
HubTokenizerConfig {
chat_template: None,
}
}),
};
// Load tokenizer config
// This will be used to format the chat template
let local_tokenizer_config_path =
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
let (tokenizer, model_info) = if local_model {
// Get Model info
let model_info = HubModelInfo {
model_id: tokenizer_name.clone(),
sha: None,
pipeline_tag: None,
};
// Load local tokenizer
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
(tokenizer, model_info)
} else {
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
@ -188,19 +167,42 @@ async fn main() -> Result<(), RouterError> {
builder = builder.with_cache_dir(cache_dir.into());
}
if revision.is_none() {
tracing::warn!("`--revision` is not set");
tracing::warn!("We strongly advise to set it to a known supported commit.");
}
builder
};
let api = builder.build().unwrap();
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
let api = if use_api {
tracing::info!("Using the Hugging Face API");
Some(api_builder().build().unwrap())
} else {
None
};
// Load tokenizer and model info
let (tokenizer, model_info) = if local_model {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
let model_info = HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
};
(tokenizer, model_info)
} else if let Some(api) = api.clone() {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.clone(),
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or("main".to_string()),
revision.clone().unwrap_or_else(|| "main".to_string()),
));
// Get Model info
let tokenizer = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
@ -210,12 +212,30 @@ async fn main() -> Result<(), RouterError> {
}
});
let tokenizer = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
(tokenizer, model_info)
} else {
// No API and no local model
return Err(RouterError::ArgumentValidation(
"No local model found and no revision specified".to_string(),
));
};
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
} else if let Some(api) = api {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()),
)))
.await
.unwrap()
} else {
tracing::warn!("Could not find tokenizer config locally and no revision specified");
HubTokenizerConfig::default()
};
if tokenizer.is_none() {
@ -421,38 +441,17 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(
model_id: &str,
revision: Option<&str>,
token: Option<&str>,
) -> Option<HubTokenizerConfig> {
let revision = match revision {
None => {
tracing::warn!("`--revision` is not set");
tracing::warn!("We strongly advise to set it to a known supported commit.");
"main".to_string()
}
Some(revision) => revision.to_string(),
};
let client = reqwest::Client::new();
// Poor man's urlencode
let revision = revision.replace('/', "%2F");
let url = format!(
"https://huggingface.co/{}/raw/{}/tokenizer_config.json",
model_id, revision
);
let mut builder = client.get(url).timeout(Duration::from_secs(5));
if let Some(token) = token {
builder = builder.bearer_auth(token);
}
let response = builder.send().await.ok()?;
if response.status().is_success() {
let text = response.text().await.ok()?;
let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
Some(hub_tokenizer_config)
} else {
None
}
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    // Resolve the config file through the repo handle; any failure yields `None`.
    let config_path = api_repo.get("tokenizer_config.json").await.ok()?;

    // Buffered, read-only handle on the resolved file.
    let config_reader = File::open(config_path).map(BufReader::new).ok()?;

    // Deserialize straight from the reader; malformed JSON also yields `None`.
    serde_json::from_reader(config_reader).ok()
}
#[derive(Debug, Error)]