/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient;
use text_generation_router::{server, ModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "main", long, env)]
    revision: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
}
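
// Note: with clap's `env` attribute, every argument above can also be supplied
// through an environment variable named after the field (e.g. MAX_BATCH_SIZE
// for --max-batch-size); an explicit command-line value takes precedence.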

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }

    // CORS allowed origins
    // Map inside the Option, parsing each String origin into a HeaderValue,
    // then convert the whole list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
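    // Note: the `.unwrap()` above makes an origin string that is not a valid
    // header value panic at startup rather than fail silently per-request.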

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
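    // Both branches use `.ok()`, so tokenizer loading errors are swallowed on
    // purpose: the server can still start without a fast tokenizer, and the
    // warnings logged below spell out what is lost in that case.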

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => ModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
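            // Both `.expect` calls above are fail-fast: an unreachable master
            // shard or a failed cache clear aborts startup instead of leaving
            // the router serving against broken shards.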

            // Binds on 0.0.0.0 (all interfaces), not just localhost
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                model_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                max_batch_size,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
            )
            .await;
            Ok(())
        })
}
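
// A hypothetical invocation sketch (the binary name is assumed to match the
// crate; flags mirror the `Args` struct above, and each one can also be set
// through its corresponding environment variable):
//
//   text-generation-router \
//       --port 8080 \
//       --tokenizer-name bigscience/bloom \
//       --max-batch-size 32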
|
2023-02-13 12:02:45 +00:00
|
|
|
|
|
|
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
|
|
|
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
|
|
|
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
|
|
|
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
|
|
|
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);
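    // `flatten_event(true)` lifts event fields to the top level of each JSON
    // record instead of nesting them under a "fields" object.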

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }
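    // `install_batch(opentelemetry::runtime::Tokio)` spawns a background task
    // that exports spans to the collector in batches, and `Sampler::AlwaysOn`
    // records every trace rather than a sampled subset.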

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// Get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
    let client = reqwest::Client::new();
    let mut builder = client.get(format!(
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
    ));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let model_info = builder
        .send()
        .await
        .expect("Could not connect to hf.co")
        .text()
        .await
        .expect("error when retrieving model info from hf.co");
    serde_json::from_str(&model_info).expect("unable to parse model info")
}
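
// Note: `serde_json::from_str` deserializes the response into `ModelInfo`
// (model_id, sha, pipeline_tag); with serde's defaults, any extra fields in
// the Hub's JSON payload are simply ignored.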