fix: re-add changes removed during rebase

drbh 2024-01-10 14:00:38 -05:00
parent 55455a16c7
commit d009aa3ee3


@@ -1,19 +1,23 @@
 /// Text Generation Inference webserver entrypoint
 use axum::http::HeaderValue;
 use clap::Parser;
+use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
+use hf_hub::{Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
+use std::fs::File;
+use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
 use thiserror::Error;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -69,7 +73,8 @@ struct Args {
     ngrok_edge: Option<String>,
 }

-fn main() -> Result<(), RouterError> {
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -98,6 +103,9 @@ fn main() -> Result<(), RouterError> {
         ngrok_edge,
     } = args;

+    // Launch Tokio runtime
+    init_logging(otlp_endpoint, json_output);
+
     // Validate args
     if max_input_length >= max_total_tokens {
         return Err(RouterError::ArgumentValidation(
@@ -141,161 +149,177 @@ fn main() -> Result<(), RouterError> {
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
-    let tokenizer = if local_model {
-        // Load local tokenizer
-        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
-    } else {
-        // Download and instantiate tokenizer
-        // We need to download it outside of the Tokio runtime
-        let params = FromPretrainedParameters {
-            revision: revision.clone().unwrap_or("main".to_string()),
-            auth_token: authorization_token.clone(),
-            ..Default::default()
-        };
-        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
-    };
-
-    // Launch Tokio runtime
-    tokio::runtime::Builder::new_multi_thread()
-        .enable_all()
-        .build()?
-        .block_on(async {
-            init_logging(otlp_endpoint, json_output);
-
-            if tokenizer.is_none() {
-                tracing::warn!(
-                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
-                );
-                tracing::warn!("Rust input length validation and truncation is disabled");
-            }
-
-            // Get Model info
-            let model_info = match local_model {
-                true => HubModelInfo {
-                    model_id: tokenizer_name.clone(),
-                    sha: None,
-                    pipeline_tag: None,
-                },
-                false => get_model_info(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-                        HubModelInfo {
-                            model_id: tokenizer_name.to_string(),
-                            sha: None,
-                            pipeline_tag: None,
-                        }
-                    }),
-            };
-
-            let tokenizer_config: HubTokenizerConfig = match local_model {
-                true => HubTokenizerConfig{
-                    chat_template: None,
-                },
-                false => get_tokenizer_config(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
-                    .await.unwrap_or_else(|| {
-                        tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
-                        HubTokenizerConfig{
-                            chat_template: None,
-                        }
-                    }),
-            };
-
-            // if pipeline-tag == text-generation we default to return_full_text = true
-            let compat_return_full_text = match &model_info.pipeline_tag {
-                None => {
-                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
-                    false
-                }
-                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
-            };
-
-            // Instantiate sharded client from the master unix socket
-            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
-                .await
-                .map_err(RouterError::Connection)?;
-            // Clear the cache; useful if the webserver rebooted
-            sharded_client
-                .clear_cache(None)
-                .await
-                .map_err(RouterError::Cache)?;
-            // Get info from the shard
-            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
-
-            // Warmup model
-            tracing::info!("Warming up model");
-            let max_supported_batch_total_tokens = match sharded_client
-                .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
-                .await
-                .map_err(RouterError::Warmup)?
-            {
-                // Older models do not support automatic max-batch-total-tokens
-                None => {
-                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
-                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
-                    );
-                    tracing::warn!("Model does not support automatic max batch total tokens");
-                    max_batch_total_tokens
-                }
-                // Flash attention models return their max supported total tokens
-                Some(max_supported_batch_total_tokens) => {
-                    // Warn if user added his own max-batch-total-tokens as we will ignore it
-                    if max_batch_total_tokens.is_some() {
-                        tracing::warn!(
-                            "`--max-batch-total-tokens` is deprecated for Flash \
-                             Attention models."
-                        );
-                        tracing::warn!(
-                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
-                        );
-                    }
-                    if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
-                    }
-
-                    max_supported_batch_total_tokens
-                }
-            };
-            tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
-            tracing::info!("Connected");
-
-            let addr = match hostname.parse() {
-                Ok(ip) => SocketAddr::new(ip, port),
-                Err(_) => {
-                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
-                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
-                }
-            };
-
-            // Run server
-            server::run(
-                model_info,
-                shard_info,
-                compat_return_full_text,
-                max_concurrent_requests,
-                max_best_of,
-                max_stop_sequences,
-                max_top_n_tokens,
-                max_input_length,
-                max_total_tokens,
-                waiting_served_ratio,
-                max_batch_prefill_tokens,
-                max_supported_batch_total_tokens,
-                max_waiting_tokens,
-                sharded_client,
-                tokenizer,
-                validation_workers,
-                addr,
-                cors_allow_origin,
-                ngrok,
-                ngrok_authtoken,
-                ngrok_edge,
-                tokenizer_config,
-            )
-            .await?;
-            Ok(())
-        })
+    let tokenizer_config: HubTokenizerConfig = match local_model {
+        true => HubTokenizerConfig {
+            chat_template: None,
+        },
+        false => get_tokenizer_config(
+            &tokenizer_name,
+            revision.as_deref(),
+            authorization_token.as_deref(),
+        )
+        .await
+        .unwrap_or_else(|| {
+            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
+            HubTokenizerConfig {
+                chat_template: None,
+            }
+        }),
+    };
+
+    let (tokenizer, model_info) = if local_model {
+        // Get Model info
+        let model_info = HubModelInfo {
+            model_id: tokenizer_name.clone(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
+        // Load local tokenizer
+        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
+
+        (tokenizer, model_info)
+    } else {
+        let mut builder = ApiBuilder::new()
+            .with_progress(false)
+            .with_token(authorization_token);
+
+        if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
+            builder = builder.with_cache_dir(cache_dir.into());
+        }
+
+        if revision.is_none() {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+        }
+
+        let api = builder.build().unwrap();
+        let api_repo = api.repo(Repo::with_revision(
+            tokenizer_name.clone(),
+            RepoType::Model,
+            revision.clone().unwrap_or("main".to_string()),
+        ));
+
+        // Get Model info
+        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
+            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+            HubModelInfo {
+                model_id: tokenizer_name.to_string(),
+                sha: None,
+                pipeline_tag: None,
+            }
+        });
+
+        let tokenizer = match api_repo.get("tokenizer.json").await {
+            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            Err(_) => get_base_tokenizer(&api, &api_repo).await,
+        };
+
+        (tokenizer, model_info)
+    };
+
+    if tokenizer.is_none() {
+        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
+        tracing::warn!("Rust input length validation and truncation is disabled");
+    }
+
+    // if pipeline-tag == text-generation we default to return_full_text = true
+    let compat_return_full_text = match &model_info.pipeline_tag {
+        None => {
+            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+            false
+        }
+        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
+    };
+
+    // Instantiate sharded client from the master unix socket
+    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+        .await
+        .map_err(RouterError::Connection)?;
+    // Clear the cache; useful if the webserver rebooted
+    sharded_client
+        .clear_cache(None)
+        .await
+        .map_err(RouterError::Cache)?;
+    // Get info from the shard
+    let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
+
+    // Warmup model
+    tracing::info!("Warming up model");
+    let max_supported_batch_total_tokens = match sharded_client
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_total_tokens as u32,
+        )
+        .await
+        .map_err(RouterError::Warmup)?
+    {
+        // Older models do not support automatic max-batch-total-tokens
+        None => {
+            let max_batch_total_tokens = max_batch_total_tokens
+                .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
+            tracing::warn!("Model does not support automatic max batch total tokens");
+            max_batch_total_tokens
+        }
+        // Flash attention models return their max supported total tokens
+        Some(max_supported_batch_total_tokens) => {
+            // Warn if user added his own max-batch-total-tokens as we will ignore it
+            if max_batch_total_tokens.is_some() {
+                tracing::warn!(
+                    "`--max-batch-total-tokens` is deprecated for Flash \
+                     Attention models."
+                );
+                tracing::warn!(
+                    "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                );
+            }
+            if max_total_tokens as u32 > max_supported_batch_total_tokens {
+                return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
+            }
+
+            max_supported_batch_total_tokens
+        }
+    };
+    tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
+    tracing::info!("Connected");
+
+    let addr = match hostname.parse() {
+        Ok(ip) => SocketAddr::new(ip, port),
+        Err(_) => {
+            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
+        }
+    };
+
+    // Run server
+    server::run(
+        model_info,
+        shard_info,
+        compat_return_full_text,
+        max_concurrent_requests,
+        max_best_of,
+        max_stop_sequences,
+        max_top_n_tokens,
+        max_input_length,
+        max_total_tokens,
+        waiting_served_ratio,
+        max_batch_prefill_tokens,
+        max_supported_batch_total_tokens,
+        max_waiting_tokens,
+        sharded_client,
+        tokenizer,
+        validation_workers,
+        addr,
+        cors_allow_origin,
+        ngrok,
+        ngrok_authtoken,
+        ngrok_edge,
+        tokenizer_config,
+    )
+    .await?;
+    Ok(())
 }

 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
@@ -354,30 +378,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(
-    model_id: &str,
-    revision: Option<&str>,
-    token: Option<&str>,
-) -> Option<HubModelInfo> {
-    let revision = match revision {
-        None => {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-            "main".to_string()
-        }
-        Some(revision) => revision.to_string(),
-    };
-
-    let client = reqwest::Client::new();
-    // Poor man's urlencode
-    let revision = revision.replace('/', "%2F");
-    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
-    let mut builder = client.get(url).timeout(Duration::from_secs(5));
-    if let Some(token) = token {
-        builder = builder.bearer_auth(token);
-    }
-
-    let response = builder.send().await.ok()?;
+pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
+    let response = api.info_request().send().await.ok()?;

     if response.status().is_success() {
         let hub_model_info: HubModelInfo =
@@ -394,6 +396,31 @@ pub async fn get_model_info(
     }
 }

+/// get base tokenizer
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+    let config_filename = api_repo.get("config.json").await.ok()?;
+
+    // Open the file in read-only mode with buffer.
+    let file = File::open(config_filename).ok()?;
+    let reader = BufReader::new(file);
+
+    // Read the JSON contents of the file as an instance of `User`.
+    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
+
+    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
+        let api_base_repo = api.repo(Repo::with_revision(
+            base_model_id.to_string(),
+            RepoType::Model,
+            "main".to_string(),
+        ));
+
+        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
+        Tokenizer::from_file(tokenizer_filename).ok()
+    } else {
+        None
+    }
+}
+
 /// get tokenizer_config from the Huggingface Hub
 pub async fn get_tokenizer_config(
     model_id: &str,