mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
feat: experimental python packaging and interface
This commit is contained in:
parent
612bc483b6
commit
0e5220d704
102
Cargo.lock
generated
102
Cargo.lock
generated
@ -1783,6 +1783,15 @@ version = "2.7.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
|
checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "metrics"
|
name = "metrics"
|
||||||
version = "0.21.1"
|
version = "0.21.1"
|
||||||
@ -2665,6 +2674,82 @@ dependencies = [
|
|||||||
"prost 0.12.6",
|
"prost 0.12.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3"
|
||||||
|
version = "0.20.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"indoc",
|
||||||
|
"libc",
|
||||||
|
"memoffset",
|
||||||
|
"parking_lot",
|
||||||
|
"portable-atomic",
|
||||||
|
"pyo3-build-config",
|
||||||
|
"pyo3-ffi",
|
||||||
|
"pyo3-macros",
|
||||||
|
"unindent",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-asyncio"
|
||||||
|
version = "0.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6ea6b68e93db3622f3bb3bf363246cf948ed5375afe7abff98ccbdd50b184995"
|
||||||
|
dependencies = [
|
||||||
|
"futures",
|
||||||
|
"once_cell",
|
||||||
|
"pin-project-lite",
|
||||||
|
"pyo3",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-build-config"
|
||||||
|
version = "0.20.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
"target-lexicon",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-ffi"
|
||||||
|
version = "0.20.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"pyo3-build-config",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-macros"
|
||||||
|
version = "0.20.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"pyo3-macros-backend",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-macros-backend"
|
||||||
|
version = "0.20.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
|
||||||
|
dependencies = [
|
||||||
|
"heck 0.4.1",
|
||||||
|
"proc-macro2",
|
||||||
|
"pyo3-build-config",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "qoi"
|
name = "qoi"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
@ -3627,6 +3712,17 @@ dependencies = [
|
|||||||
"vergen",
|
"vergen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tgi"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"pyo3",
|
||||||
|
"pyo3-asyncio",
|
||||||
|
"text-generation-launcher",
|
||||||
|
"text-generation-router",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.61"
|
version = "1.0.61"
|
||||||
@ -4186,6 +4282,12 @@ version = "0.1.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unindent"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "untrusted"
|
name = "untrusted"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
@ -4,7 +4,8 @@ members = [
|
|||||||
"router",
|
"router",
|
||||||
"router/client",
|
"router/client",
|
||||||
"router/grpc-metadata",
|
"router/grpc-metadata",
|
||||||
"launcher"
|
"launcher",
|
||||||
|
"tgi"
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
|
2026
launcher/src/lib.rs
Normal file
2026
launcher/src/lib.rs
Normal file
File diff suppressed because it is too large
Load Diff
1627
launcher/src/main.rs
1627
launcher/src/main.rs
File diff suppressed because it is too large
Load Diff
@ -6,15 +6,491 @@ mod queue;
|
|||||||
pub mod server;
|
pub mod server;
|
||||||
mod validation;
|
mod validation;
|
||||||
|
|
||||||
|
use axum::http::HeaderValue;
|
||||||
|
use config::Config;
|
||||||
|
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||||
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
use infer::{Infer, InferError, InferStreamResponse};
|
use infer::{Infer, InferError, InferStreamResponse};
|
||||||
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
|
use opentelemetry::sdk::trace;
|
||||||
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
|
use opentelemetry::sdk::Resource;
|
||||||
|
use opentelemetry::{global, KeyValue};
|
||||||
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
use queue::{Entry, Queue};
|
use queue::{Entry, Queue};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::OwnedSemaphorePermit;
|
use tokio::sync::OwnedSemaphorePermit;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||||
|
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||||
|
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||||
|
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||||
|
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||||
|
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||||
|
let mut layers = Vec::new();
|
||||||
|
|
||||||
|
// STDOUT/STDERR layer
|
||||||
|
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||||
|
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||||
|
.with_file(true)
|
||||||
|
.with_ansi(ansi)
|
||||||
|
.with_line_number(true);
|
||||||
|
|
||||||
|
let fmt_layer = match json_output {
|
||||||
|
true => fmt_layer.json().flatten_event(true).boxed(),
|
||||||
|
false => fmt_layer.boxed(),
|
||||||
|
};
|
||||||
|
layers.push(fmt_layer);
|
||||||
|
|
||||||
|
// OpenTelemetry tracing layer
|
||||||
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||||
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||||
|
|
||||||
|
let tracer = opentelemetry_otlp::new_pipeline()
|
||||||
|
.tracing()
|
||||||
|
.with_exporter(
|
||||||
|
opentelemetry_otlp::new_exporter()
|
||||||
|
.tonic()
|
||||||
|
.with_endpoint(otlp_endpoint),
|
||||||
|
)
|
||||||
|
.with_trace_config(
|
||||||
|
trace::config()
|
||||||
|
.with_resource(Resource::new(vec![KeyValue::new(
|
||||||
|
"service.name",
|
||||||
|
"text-generation-inference.router",
|
||||||
|
)]))
|
||||||
|
.with_sampler(Sampler::AlwaysOn),
|
||||||
|
)
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio);
|
||||||
|
|
||||||
|
if let Ok(tracer) = tracer {
|
||||||
|
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||||
|
init_tracing_opentelemetry::init_propagator().unwrap();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter events with LOG_LEVEL
|
||||||
|
let env_filter =
|
||||||
|
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(env_filter)
|
||||||
|
.with(layers)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get model info from the Huggingface Hub
|
||||||
|
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||||
|
let response = api.info_request().send().await.ok()?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
let hub_model_info: HubModelInfo =
|
||||||
|
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||||
|
if let Some(sha) = &hub_model_info.sha {
|
||||||
|
tracing::info!(
|
||||||
|
"Serving revision {sha} of model {}",
|
||||||
|
hub_model_info.model_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Some(hub_model_info)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get base tokenizer
|
||||||
|
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of `User`.
|
||||||
|
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
||||||
|
|
||||||
|
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
||||||
|
let api_base_repo = api.repo(Repo::with_revision(
|
||||||
|
base_model_id.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
api_base_repo.get("tokenizer.json").await.ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get tokenizer_config from the Huggingface Hub
|
||||||
|
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(tokenizer_config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
||||||
|
e
|
||||||
|
})
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
Some(tokenizer_config)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum RouterError {
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
Tokio(#[from] std::io::Error),
|
||||||
|
#[error("Axum webserver failed: {0}")]
|
||||||
|
Axum(#[from] axum::BoxError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn internal_main(
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
max_best_of: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
max_input_tokens: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
hostname: String,
|
||||||
|
port: u16,
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
tokenizer_name: String,
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
revision: Option<String>,
|
||||||
|
validation_workers: usize,
|
||||||
|
json_output: bool,
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
ngrok: bool,
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
messages_api_enabled: bool,
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
) -> Result<(), RouterError> {
|
||||||
|
// Launch Tokio runtime
|
||||||
|
if otlp_endpoint.is_some() {
|
||||||
|
// Initialize if OpenTelemetry is enabled
|
||||||
|
init_logging(otlp_endpoint, json_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate args
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation_workers == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORS allowed origins
|
||||||
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
|
// Finally, convert to AllowOrigin
|
||||||
|
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||||
|
AllowOrigin::list(
|
||||||
|
cors_allow_origin
|
||||||
|
.iter()
|
||||||
|
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Parse Huggingface hub token
|
||||||
|
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
|
// This will only be used to validate payloads
|
||||||
|
let local_path = Path::new(&tokenizer_name);
|
||||||
|
|
||||||
|
// Shared API builder initialization
|
||||||
|
let api_builder = || {
|
||||||
|
let mut builder = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_token(authorization_token);
|
||||||
|
|
||||||
|
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum Type {
|
||||||
|
Api(Api),
|
||||||
|
Cache(Cache),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
let api = if use_api {
|
||||||
|
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||||
|
let cache = Cache::default();
|
||||||
|
tracing::warn!("Offline mode active using cache defaults");
|
||||||
|
Type::Cache(cache)
|
||||||
|
} else {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
match api_builder().build() {
|
||||||
|
Ok(api) => Type::Api(api),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
|
Type::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Type::None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
|
||||||
|
Type::None => (
|
||||||
|
Some(local_path.join("tokenizer.json")),
|
||||||
|
Some(local_path.join("config.json")),
|
||||||
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
Type::Api(api) => {
|
||||||
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
|
||||||
|
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||||
|
Some(model_info)
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
model_info,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type::Cache(cache) => {
|
||||||
|
let repo = cache.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
|
));
|
||||||
|
(
|
||||||
|
repo.get("tokenizer.json"),
|
||||||
|
repo.get("config.json"),
|
||||||
|
repo.get("tokenizer_config.json"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tokenizer: Option<Tokenizer> =
|
||||||
|
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||||
|
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||||
|
std::fs::read_to_string(filename)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|c| {
|
||||||
|
let config: Result<Config, _> = serde_json::from_str(c);
|
||||||
|
if let Err(err) = &config {
|
||||||
|
tracing::warn!("Could not parse config {err:?}");
|
||||||
|
}
|
||||||
|
config.ok()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||||
|
model_id: tokenizer_name.to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
|
{
|
||||||
|
HubTokenizerConfig::from_file(filename)
|
||||||
|
} else {
|
||||||
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
|
};
|
||||||
|
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||||
|
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
tracing::info!("Using config {config:?}");
|
||||||
|
if tokenizer.is_none() {
|
||||||
|
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||||
|
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
|
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||||
|
None => {
|
||||||
|
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Instantiate sharded client from the master unix socket
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Connection)?;
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_supported_batch_total_tokens = match sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_tokens as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Warmup)?
|
||||||
|
{
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
max_batch_total_tokens
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
max_supported_batch_total_tokens
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||||
|
tracing::info!("Connected");
|
||||||
|
|
||||||
|
// Determine the server port based on the feature and environment variable.
|
||||||
|
let port = if cfg!(feature = "google") {
|
||||||
|
std::env::var("AIP_HTTP_PORT")
|
||||||
|
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||||
|
.unwrap_or(port)
|
||||||
|
} else {
|
||||||
|
port
|
||||||
|
};
|
||||||
|
|
||||||
|
let addr = match hostname.parse() {
|
||||||
|
Ok(ip) => SocketAddr::new(ip, port),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
server::run(
|
||||||
|
model_info,
|
||||||
|
shard_info,
|
||||||
|
compat_return_full_text,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_supported_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
sharded_client,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
validation_workers,
|
||||||
|
addr,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
tokenizer_config,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Type alias for generation responses
|
/// Type alias for generation responses
|
||||||
pub(crate) type GenerateStreamResponse = (
|
pub(crate) type GenerateStreamResponse = (
|
||||||
OwnedSemaphorePermit,
|
OwnedSemaphorePermit,
|
||||||
|
@ -1,26 +1,5 @@
|
|||||||
use axum::http::HeaderValue;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
use text_generation_router::{internal_main, RouterError};
|
||||||
use hf_hub::{Cache, Repo, RepoType};
|
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
|
||||||
use opentelemetry::sdk::trace;
|
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
|
||||||
use opentelemetry::sdk::Resource;
|
|
||||||
use opentelemetry::{global, KeyValue};
|
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::BufReader;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
|
||||||
use text_generation_router::config::Config;
|
|
||||||
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use tower_http::cors::AllowOrigin;
|
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -86,487 +65,36 @@ struct Args {
|
|||||||
async fn main() -> Result<(), RouterError> {
|
async fn main() -> Result<(), RouterError> {
|
||||||
// Get args
|
// Get args
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
// Pattern match configuration
|
|
||||||
let Args {
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
hostname,
|
|
||||||
port,
|
|
||||||
master_shard_uds_path,
|
|
||||||
tokenizer_name,
|
|
||||||
tokenizer_config_path,
|
|
||||||
revision,
|
|
||||||
validation_workers,
|
|
||||||
json_output,
|
|
||||||
otlp_endpoint,
|
|
||||||
cors_allow_origin,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
} = args;
|
|
||||||
|
|
||||||
// Launch Tokio runtime
|
internal_main(
|
||||||
init_logging(otlp_endpoint, json_output);
|
args.max_concurrent_requests,
|
||||||
|
args.max_best_of,
|
||||||
// Validate args
|
args.max_stop_sequences,
|
||||||
if max_input_tokens >= max_total_tokens {
|
args.max_top_n_tokens,
|
||||||
return Err(RouterError::ArgumentValidation(
|
args.max_input_tokens,
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
args.max_total_tokens,
|
||||||
));
|
args.waiting_served_ratio,
|
||||||
}
|
args.max_batch_prefill_tokens,
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
args.max_batch_total_tokens,
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
args.max_waiting_tokens,
|
||||||
}
|
args.max_batch_size,
|
||||||
|
args.hostname,
|
||||||
if validation_workers == 0 {
|
args.port,
|
||||||
return Err(RouterError::ArgumentValidation(
|
args.master_shard_uds_path,
|
||||||
"`validation_workers` must be > 0".to_string(),
|
args.tokenizer_name,
|
||||||
));
|
args.tokenizer_config_path,
|
||||||
}
|
args.revision,
|
||||||
|
args.validation_workers,
|
||||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
args.json_output,
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
args.otlp_endpoint,
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
args.cors_allow_origin,
|
||||||
}
|
args.ngrok,
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
args.ngrok_authtoken,
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
args.ngrok_edge,
|
||||||
}
|
args.messages_api_enabled,
|
||||||
}
|
args.disable_grammar_support,
|
||||||
|
args.max_client_batch_size,
|
||||||
// CORS allowed origins
|
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
|
||||||
// Finally, convert to AllowOrigin
|
|
||||||
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
|
||||||
AllowOrigin::list(
|
|
||||||
cors_allow_origin
|
|
||||||
.iter()
|
|
||||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Parse Huggingface hub token
|
|
||||||
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
|
||||||
|
|
||||||
// Tokenizer instance
|
|
||||||
// This will only be used to validate payloads
|
|
||||||
let local_path = Path::new(&tokenizer_name);
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
|
||||||
let api_builder = || {
|
|
||||||
let mut builder = ApiBuilder::new()
|
|
||||||
.with_progress(false)
|
|
||||||
.with_token(authorization_token);
|
|
||||||
|
|
||||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
|
||||||
builder = builder.with_cache_dir(cache_dir.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
builder
|
|
||||||
};
|
|
||||||
|
|
||||||
// Decide if we need to use the API based on the revision and local path
|
|
||||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
|
||||||
|
|
||||||
// Initialize API if needed
|
|
||||||
#[derive(Clone)]
|
|
||||||
enum Type {
|
|
||||||
Api(Api),
|
|
||||||
Cache(Cache),
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
let api = if use_api {
|
|
||||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
|
||||||
let cache = Cache::default();
|
|
||||||
tracing::warn!("Offline mode active using cache defaults");
|
|
||||||
Type::Cache(cache)
|
|
||||||
} else {
|
|
||||||
tracing::info!("Using the Hugging Face API");
|
|
||||||
match api_builder().build() {
|
|
||||||
Ok(api) => Type::Api(api),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Unable to build the Hugging Face API");
|
|
||||||
Type::None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Type::None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load tokenizer and model info
|
|
||||||
let (
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
) = match api {
|
|
||||||
Type::None => (
|
|
||||||
Some(local_path.join("tokenizer.json")),
|
|
||||||
Some(local_path.join("config.json")),
|
|
||||||
Some(local_path.join("tokenizer_config.json")),
|
|
||||||
Some(local_path.join("processor_config.json")),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
Type::Api(api) => {
|
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
|
|
||||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
|
||||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
|
||||||
};
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok();
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
|
||||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
|
||||||
Some(model_info)
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Type::Cache(cache) => {
|
|
||||||
let repo = cache.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
repo.get("tokenizer.json"),
|
|
||||||
repo.get("config.json"),
|
|
||||||
repo.get("tokenizer_config.json"),
|
|
||||||
repo.get("processor_config.json"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let tokenizer: Option<Tokenizer> =
|
|
||||||
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
|
||||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
|
||||||
std::fs::read_to_string(filename)
|
|
||||||
.ok()
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|c| {
|
|
||||||
let config: Result<Config, _> = serde_json::from_str(c);
|
|
||||||
if let Err(err) = &config {
|
|
||||||
tracing::warn!("Could not parse config {err:?}");
|
|
||||||
}
|
|
||||||
config.ok()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
|
||||||
{
|
|
||||||
HubTokenizerConfig::from_file(filename)
|
|
||||||
} else {
|
|
||||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
|
||||||
};
|
|
||||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
let processor_config = processor_config_filename
|
|
||||||
.and_then(HubProcessorConfig::from_file)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
tracing::info!("Using config {config:?}");
|
|
||||||
if tokenizer.is_none() {
|
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
|
||||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
|
||||||
None => {
|
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
|
||||||
true
|
|
||||||
}
|
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Instantiate sharded client from the master unix socket
|
|
||||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
|
||||||
.await
|
|
||||||
.map_err(RouterError::Connection)?;
|
|
||||||
// Clear the cache; useful if the webserver rebooted
|
|
||||||
sharded_client
|
|
||||||
.clear_cache(None)
|
|
||||||
.await
|
|
||||||
.map_err(RouterError::Cache)?;
|
|
||||||
// Get info from the shard
|
|
||||||
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
|
|
||||||
|
|
||||||
// Warmup model
|
|
||||||
tracing::info!("Warming up model");
|
|
||||||
let max_supported_batch_total_tokens = match sharded_client
|
|
||||||
.warmup(
|
|
||||||
max_input_tokens as u32,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_total_tokens as u32,
|
|
||||||
max_batch_size,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(RouterError::Warmup)?
|
|
||||||
{
|
|
||||||
// Older models do not support automatic max-batch-total-tokens
|
|
||||||
None => {
|
|
||||||
let max_batch_total_tokens = max_batch_total_tokens
|
|
||||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
|
||||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
|
||||||
max_batch_total_tokens
|
|
||||||
}
|
|
||||||
// Flash attention models return their max supported total tokens
|
|
||||||
Some(max_supported_batch_total_tokens) => {
|
|
||||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
|
||||||
if max_batch_total_tokens.is_some() {
|
|
||||||
tracing::warn!(
|
|
||||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
|
||||||
Attention models."
|
|
||||||
);
|
|
||||||
tracing::warn!(
|
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
max_supported_batch_total_tokens
|
|
||||||
}
|
|
||||||
};
|
|
||||||
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
|
||||||
tracing::info!("Connected");
|
|
||||||
|
|
||||||
// Determine the server port based on the feature and environment variable.
|
|
||||||
let port = if cfg!(feature = "google") {
|
|
||||||
std::env::var("AIP_HTTP_PORT")
|
|
||||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
|
||||||
.unwrap_or(port)
|
|
||||||
} else {
|
|
||||||
port
|
|
||||||
};
|
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
|
||||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run server
|
|
||||||
server::run(
|
|
||||||
model_info,
|
|
||||||
shard_info,
|
|
||||||
compat_return_full_text,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_supported_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
sharded_client,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
validation_workers,
|
|
||||||
addr,
|
|
||||||
cors_allow_origin,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
tokenizer_config,
|
|
||||||
processor_config,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
|
||||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
|
||||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
|
||||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
|
||||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
|
||||||
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|
||||||
let mut layers = Vec::new();
|
|
||||||
|
|
||||||
// STDOUT/STDERR layer
|
|
||||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
|
||||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
|
||||||
.with_file(true)
|
|
||||||
.with_ansi(ansi)
|
|
||||||
.with_line_number(true);
|
|
||||||
|
|
||||||
let fmt_layer = match json_output {
|
|
||||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
|
||||||
false => fmt_layer.boxed(),
|
|
||||||
};
|
|
||||||
layers.push(fmt_layer);
|
|
||||||
|
|
||||||
// OpenTelemetry tracing layer
|
|
||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
|
||||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
||||||
|
|
||||||
let tracer = opentelemetry_otlp::new_pipeline()
|
|
||||||
.tracing()
|
|
||||||
.with_exporter(
|
|
||||||
opentelemetry_otlp::new_exporter()
|
|
||||||
.tonic()
|
|
||||||
.with_endpoint(otlp_endpoint),
|
|
||||||
)
|
|
||||||
.with_trace_config(
|
|
||||||
trace::config()
|
|
||||||
.with_resource(Resource::new(vec![KeyValue::new(
|
|
||||||
"service.name",
|
|
||||||
"text-generation-inference.router",
|
|
||||||
)]))
|
|
||||||
.with_sampler(Sampler::AlwaysOn),
|
|
||||||
)
|
|
||||||
.install_batch(opentelemetry::runtime::Tokio);
|
|
||||||
|
|
||||||
if let Ok(tracer) = tracer {
|
|
||||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
|
||||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter events with LOG_LEVEL
|
|
||||||
let varname = "LOG_LEVEL";
|
|
||||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
|
||||||
// Override to avoid simple logs to be spammed with tokio level informations
|
|
||||||
let log_level = match &log_level[..] {
|
|
||||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
|
||||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
|
||||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
|
||||||
log_level => log_level,
|
|
||||||
};
|
|
||||||
EnvFilter::builder()
|
|
||||||
.with_default_directive(LevelFilter::INFO.into())
|
|
||||||
.parse_lossy(log_level)
|
|
||||||
} else {
|
|
||||||
EnvFilter::new("info")
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(layers)
|
|
||||||
.init();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
|
||||||
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
|
||||||
let response = api.info_request().send().await.ok()?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let hub_model_info: HubModelInfo =
|
|
||||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
|
||||||
if let Some(sha) = &hub_model_info.sha {
|
|
||||||
tracing::info!(
|
|
||||||
"Serving revision {sha} of model {}",
|
|
||||||
hub_model_info.model_id
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Some(hub_model_info)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get base tokenizer
|
|
||||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of `User`.
|
|
||||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
|
||||||
|
|
||||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
|
||||||
let api_base_repo = api.repo(Repo::with_revision(
|
|
||||||
base_model_id.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
"main".to_string(),
|
|
||||||
));
|
|
||||||
|
|
||||||
api_base_repo.get("tokenizer.json").await.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get tokenizer_config from the Huggingface Hub
|
|
||||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(tokenizer_config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
|
||||||
e
|
|
||||||
})
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(tokenizer_config)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
enum RouterError {
|
|
||||||
#[error("Argument validation error: {0}")]
|
|
||||||
ArgumentValidation(String),
|
|
||||||
#[error("Unable to connect to the Python model shards: {0}")]
|
|
||||||
Connection(ClientError),
|
|
||||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
|
||||||
Cache(ClientError),
|
|
||||||
#[error("Unable to get the Python model shards info: {0}")]
|
|
||||||
Info(ClientError),
|
|
||||||
#[error("Unable to warmup the Python model shards: {0}")]
|
|
||||||
Warmup(ClientError),
|
|
||||||
#[error("Tokio runtime failed to start: {0}")]
|
|
||||||
Tokio(#[from] std::io::Error),
|
|
||||||
#[error("Axum webserver failed: {0}")]
|
|
||||||
Axum(#[from] axum::BoxError),
|
|
||||||
}
|
|
||||||
|
72
tgi/.gitignore
vendored
Normal file
72
tgi/.gitignore
vendored
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
/target
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
.pytest_cache/
|
||||||
|
*.py[cod]
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
.venv/
|
||||||
|
env/
|
||||||
|
bin/
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
include/
|
||||||
|
man/
|
||||||
|
venv/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
pip-selfcheck.json
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
|
||||||
|
# Mr Developer
|
||||||
|
.mr.developer.cfg
|
||||||
|
.project
|
||||||
|
.pydevproject
|
||||||
|
|
||||||
|
# Rope
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# VSCode
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# Pyenv
|
||||||
|
.python-version
|
16
tgi/Cargo.toml
Normal file
16
tgi/Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
[package]
|
||||||
|
name = "tgi"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
[lib]
|
||||||
|
name = "tgi"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
pyo3 = { version = "0.20.0", features = ["extension-module"] }
|
||||||
|
pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] }
|
||||||
|
tokio = "1.4"
|
||||||
|
text-generation-router = { path = "../router" }
|
||||||
|
text-generation-launcher = { path = "../launcher" }
|
6
tgi/Makefile
Normal file
6
tgi/Makefile
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
|
||||||
|
build:
|
||||||
|
maturin build
|
||||||
|
|
||||||
|
install: build
|
||||||
|
pip install -e .
|
47
tgi/README.md
Normal file
47
tgi/README.md
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# TGI (Python Package)
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> This is an experimental package and intended for research purposes only. The package is likely to change and should only be used for testing and development.
|
||||||
|
|
||||||
|
`tgi` is a simple Python package that wraps the `text-generation-server` and `text-generation-launcher` packages. It provides a simple interface to the text generation server.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make install
|
||||||
|
# this compiles the code and runs pip install for `tgi`
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
See the full example in the [`app.py`](./app.py) file.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from tgi import TGI
|
||||||
|
from huggingface_hub import InferenceClient
|
||||||
|
import time
|
||||||
|
|
||||||
|
llm = TGI(model_id="google/paligemma-3b-mix-224")
|
||||||
|
|
||||||
|
# ✂️ startup logic snipped
|
||||||
|
print("Model is ready!")
|
||||||
|
|
||||||
|
client = InferenceClient("http://localhost:3000")
|
||||||
|
generated = client.text_generation("What are the main characteristics of a cat?")
|
||||||
|
print(generated)
|
||||||
|
|
||||||
|
# Cats are known for their independent nature, curious minds, and affectionate nature. Here are the main characteristics of a cat...
|
||||||
|
|
||||||
|
llm.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
Technically this is a `pyo3` package that wraps the `text-generation-server` and `text-generation-launcher` packages, and slightly modifies the launcher to rely on the interal code rather than launch an external binary.
|
||||||
|
|
||||||
|
## Known issues/limitations
|
||||||
|
|
||||||
|
- [ ] server does not gracefully handle shutdowns (trying to avoid python context for better notebook dev experience)
|
||||||
|
- [ ] issues with tracing (launcher and router should share tracer)
|
||||||
|
- [ ] text-generation-server is not integrated and still relies on the external install
|
||||||
|
- [ ] not all parameters are exposed/passed through
|
||||||
|
- [ ] general cleanup and refactoring needed
|
||||||
|
- [ ] review naming and explore pushing to PyPi
|
38
tgi/app.py
Normal file
38
tgi/app.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from tgi import TGI
|
||||||
|
from huggingface_hub import InferenceClient
|
||||||
|
import time
|
||||||
|
|
||||||
|
llm = TGI(model_id="google/paligemma-3b-mix-224")
|
||||||
|
client = InferenceClient("http://localhost:3000")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print("Waiting for the model to be ready...")
|
||||||
|
try:
|
||||||
|
time.sleep(5)
|
||||||
|
generated = client.text_generation("What is Deep Learning?")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
print("Model is ready!")
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# do a couple of inference requests
|
||||||
|
print("Generating text...")
|
||||||
|
generated = client.text_generation("Where is the capital of France?")
|
||||||
|
print(generated)
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
generated = client.text_generation(
|
||||||
|
"What can you tell me about the history of the United States?"
|
||||||
|
)
|
||||||
|
print(generated)
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
generated = client.text_generation("What are the main characteristics of a cat?")
|
||||||
|
print(generated)
|
||||||
|
|
||||||
|
llm.close()
|
15
tgi/pyproject.toml
Normal file
15
tgi/pyproject.toml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["maturin>=1.5,<2.0"]
|
||||||
|
build-backend = "maturin"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "tgi"
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Rust",
|
||||||
|
"Programming Language :: Python :: Implementation :: CPython",
|
||||||
|
"Programming Language :: Python :: Implementation :: PyPy",
|
||||||
|
]
|
||||||
|
dynamic = ["version"]
|
||||||
|
[tool.maturin]
|
||||||
|
features = ["pyo3/extension-module"]
|
455
tgi/src/lib.rs
Normal file
455
tgi/src/lib.rs
Normal file
@ -0,0 +1,455 @@
|
|||||||
|
use pyo3::{prelude::*, wrap_pyfunction};
|
||||||
|
use text_generation_launcher::{launcher_main, launcher_main_without_server};
|
||||||
|
use text_generation_router::internal_main;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
sharded,
|
||||||
|
num_shard,
|
||||||
|
_quantize,
|
||||||
|
speculate,
|
||||||
|
_dtype,
|
||||||
|
trust_remote_code,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
cuda_graphs,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
shard_uds_path,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
|
disable_custom_kernels,
|
||||||
|
cuda_memory_fraction,
|
||||||
|
_rope_scaling,
|
||||||
|
rope_factor,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
tokenizer_config_path,
|
||||||
|
disable_grammar_support,
|
||||||
|
env,
|
||||||
|
max_client_batch_size,
|
||||||
|
))]
|
||||||
|
fn rust_launcher(
|
||||||
|
py: Python<'_>,
|
||||||
|
model_id: String,
|
||||||
|
revision: Option<String>,
|
||||||
|
validation_workers: usize,
|
||||||
|
sharded: Option<bool>,
|
||||||
|
num_shard: Option<usize>,
|
||||||
|
_quantize: Option<String>, // Option<Quantization>,
|
||||||
|
speculate: Option<usize>,
|
||||||
|
_dtype: Option<String>, // Option<Dtype>,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
max_best_of: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
max_input_tokens: Option<usize>,
|
||||||
|
max_input_length: Option<usize>,
|
||||||
|
max_total_tokens: Option<usize>,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: Option<u32>,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
cuda_graphs: Option<Vec<usize>>,
|
||||||
|
hostname: String,
|
||||||
|
port: u16,
|
||||||
|
shard_uds_path: String,
|
||||||
|
master_addr: String,
|
||||||
|
master_port: usize,
|
||||||
|
huggingface_hub_cache: Option<String>,
|
||||||
|
weights_cache_override: Option<String>,
|
||||||
|
disable_custom_kernels: bool,
|
||||||
|
cuda_memory_fraction: f32,
|
||||||
|
_rope_scaling: Option<f32>, // Option<RopeScaling>,
|
||||||
|
rope_factor: Option<f32>,
|
||||||
|
json_output: bool,
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
cors_allow_origin: Vec<String>,
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
|
ngrok: bool,
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
env: bool,
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
) -> PyResult<&PyAny> {
|
||||||
|
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||||
|
launcher_main(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
sharded,
|
||||||
|
num_shard,
|
||||||
|
None,
|
||||||
|
speculate,
|
||||||
|
None,
|
||||||
|
trust_remote_code,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
cuda_graphs,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
shard_uds_path,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
|
disable_custom_kernels,
|
||||||
|
cuda_memory_fraction,
|
||||||
|
None,
|
||||||
|
rope_factor,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
tokenizer_config_path,
|
||||||
|
disable_grammar_support,
|
||||||
|
env,
|
||||||
|
max_client_batch_size,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(Python::with_gil(|py| py.None()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
sharded,
|
||||||
|
num_shard,
|
||||||
|
_quantize,
|
||||||
|
speculate,
|
||||||
|
_dtype,
|
||||||
|
trust_remote_code,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
cuda_graphs,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
shard_uds_path,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
|
disable_custom_kernels,
|
||||||
|
cuda_memory_fraction,
|
||||||
|
_rope_scaling,
|
||||||
|
rope_factor,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
tokenizer_config_path,
|
||||||
|
disable_grammar_support,
|
||||||
|
env,
|
||||||
|
max_client_batch_size,
|
||||||
|
))]
|
||||||
|
fn fully_packaged(
|
||||||
|
py: Python<'_>,
|
||||||
|
model_id: String,
|
||||||
|
revision: Option<String>,
|
||||||
|
validation_workers: usize,
|
||||||
|
sharded: Option<bool>,
|
||||||
|
num_shard: Option<usize>,
|
||||||
|
_quantize: Option<String>, // Option<Quantization>,
|
||||||
|
speculate: Option<usize>,
|
||||||
|
_dtype: Option<String>, // Option<Dtype>,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
max_best_of: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
max_input_tokens: Option<usize>,
|
||||||
|
max_input_length: Option<usize>,
|
||||||
|
max_total_tokens: Option<usize>,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: Option<u32>,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
cuda_graphs: Option<Vec<usize>>,
|
||||||
|
hostname: String,
|
||||||
|
port: u16,
|
||||||
|
shard_uds_path: String,
|
||||||
|
master_addr: String,
|
||||||
|
master_port: usize,
|
||||||
|
huggingface_hub_cache: Option<String>,
|
||||||
|
weights_cache_override: Option<String>,
|
||||||
|
disable_custom_kernels: bool,
|
||||||
|
cuda_memory_fraction: f32,
|
||||||
|
_rope_scaling: Option<f32>, // Option<RopeScaling>,
|
||||||
|
rope_factor: Option<f32>,
|
||||||
|
json_output: bool,
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
cors_allow_origin: Vec<String>,
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
|
ngrok: bool,
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
env: bool,
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
) -> PyResult<&PyAny> {
|
||||||
|
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||||
|
use std::thread;
|
||||||
|
use tokio::runtime::Runtime;
|
||||||
|
|
||||||
|
let model_id_clone = model_id.clone();
|
||||||
|
let max_concurrent_requests_clone = max_concurrent_requests;
|
||||||
|
let max_best_of_clone = max_best_of;
|
||||||
|
let max_stop_sequences_clone = max_stop_sequences;
|
||||||
|
let max_top_n_tokens_clone = max_top_n_tokens;
|
||||||
|
let max_input_tokens_clone = max_input_tokens.unwrap_or(1024);
|
||||||
|
let max_total_tokens_clone = max_total_tokens.unwrap_or(2048);
|
||||||
|
let waiting_served_ratio_clone = waiting_served_ratio;
|
||||||
|
|
||||||
|
let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096);
|
||||||
|
let max_batch_total_tokens_clone = max_batch_total_tokens;
|
||||||
|
let max_waiting_tokens_clone = max_waiting_tokens;
|
||||||
|
let max_batch_size_clone = max_batch_size;
|
||||||
|
let hostname_clone = hostname.clone();
|
||||||
|
let port_clone = port;
|
||||||
|
|
||||||
|
// TODO: fix this
|
||||||
|
let _shard_uds_path_clone = shard_uds_path.clone();
|
||||||
|
|
||||||
|
let tokenizer_config_path = tokenizer_config_path.clone();
|
||||||
|
let revision = revision.clone();
|
||||||
|
let validation_workers = validation_workers;
|
||||||
|
let json_output = json_output;
|
||||||
|
|
||||||
|
let otlp_endpoint = otlp_endpoint.clone();
|
||||||
|
let cors_allow_origin = cors_allow_origin.clone();
|
||||||
|
let ngrok = ngrok;
|
||||||
|
let ngrok_authtoken = ngrok_authtoken.clone();
|
||||||
|
let ngrok_edge = ngrok_edge.clone();
|
||||||
|
let messages_api_enabled = true;
|
||||||
|
let disable_grammar_support = disable_grammar_support;
|
||||||
|
let max_client_batch_size = max_client_batch_size;
|
||||||
|
|
||||||
|
let ngrok_clone = ngrok;
|
||||||
|
let ngrok_authtoken_clone = ngrok_authtoken.clone();
|
||||||
|
let ngrok_edge_clone = ngrok_edge.clone();
|
||||||
|
let messages_api_enabled_clone = messages_api_enabled;
|
||||||
|
let disable_grammar_support_clone = disable_grammar_support;
|
||||||
|
let max_client_batch_size_clone = max_client_batch_size;
|
||||||
|
|
||||||
|
let tokenizer_config_path_clone = tokenizer_config_path.clone();
|
||||||
|
let revision_clone = revision.clone();
|
||||||
|
let validation_workers_clone = validation_workers;
|
||||||
|
let json_output_clone = json_output;
|
||||||
|
let otlp_endpoint_clone = otlp_endpoint.clone();
|
||||||
|
|
||||||
|
let webserver_callback = move || {
|
||||||
|
let handle = thread::spawn(move || {
|
||||||
|
let rt = Runtime::new().unwrap();
|
||||||
|
rt.block_on(async {
|
||||||
|
internal_main(
|
||||||
|
max_concurrent_requests_clone,
|
||||||
|
max_best_of_clone,
|
||||||
|
max_stop_sequences_clone,
|
||||||
|
max_top_n_tokens_clone,
|
||||||
|
max_input_tokens_clone,
|
||||||
|
max_total_tokens_clone,
|
||||||
|
waiting_served_ratio_clone,
|
||||||
|
max_batch_prefill_tokens_clone,
|
||||||
|
max_batch_total_tokens_clone,
|
||||||
|
max_waiting_tokens_clone,
|
||||||
|
max_batch_size_clone,
|
||||||
|
hostname_clone,
|
||||||
|
port_clone,
|
||||||
|
"/tmp/text-generation-server-0".to_string(),
|
||||||
|
model_id_clone,
|
||||||
|
tokenizer_config_path_clone,
|
||||||
|
revision_clone,
|
||||||
|
validation_workers_clone,
|
||||||
|
json_output_clone,
|
||||||
|
otlp_endpoint_clone,
|
||||||
|
None,
|
||||||
|
ngrok_clone,
|
||||||
|
ngrok_authtoken_clone,
|
||||||
|
ngrok_edge_clone,
|
||||||
|
messages_api_enabled_clone,
|
||||||
|
disable_grammar_support_clone,
|
||||||
|
max_client_batch_size_clone,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
});
|
||||||
|
match handle.join() {
|
||||||
|
Ok(_) => println!("Server exited successfully"),
|
||||||
|
Err(e) => println!("Server exited with error: {:?}", e),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
// parse the arguments and run the main function
|
||||||
|
launcher_main_without_server(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
sharded,
|
||||||
|
num_shard,
|
||||||
|
None,
|
||||||
|
speculate,
|
||||||
|
None,
|
||||||
|
trust_remote_code,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
cuda_graphs,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
shard_uds_path,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
|
disable_custom_kernels,
|
||||||
|
cuda_memory_fraction,
|
||||||
|
None,
|
||||||
|
rope_factor,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
tokenizer_config_path,
|
||||||
|
disable_grammar_support,
|
||||||
|
env,
|
||||||
|
max_client_batch_size,
|
||||||
|
Box::new(webserver_callback),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(Python::with_gil(|py| py.None()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Asynchronous sleep function.
|
||||||
|
#[pyfunction]
|
||||||
|
fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
|
||||||
|
pyo3_asyncio::tokio::future_into_py(py, async {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(20)).await;
|
||||||
|
Ok(Python::with_gil(|py| py.None()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove hardcoding
|
||||||
|
#[pyfunction]
|
||||||
|
fn rust_server(py: Python<'_>) -> PyResult<&PyAny> {
|
||||||
|
pyo3_asyncio::tokio::future_into_py(py, async {
|
||||||
|
let _ = internal_main(
|
||||||
|
128, // max_concurrent_requests: usize,
|
||||||
|
2, // max_best_of: usize,
|
||||||
|
4, // max_stop_sequences: usize,
|
||||||
|
5, // max_top_n_tokens: u32,
|
||||||
|
1024, // max_input_tokens: usize,
|
||||||
|
2048, // max_total_tokens: usize,
|
||||||
|
1.2, // waiting_served_ratio: f32,
|
||||||
|
4096, // max_batch_prefill_tokens: u32,
|
||||||
|
None, // max_batch_total_tokens: Option<u32>,
|
||||||
|
20, // max_waiting_tokens: usize,
|
||||||
|
None, // max_batch_size: Option<usize>,
|
||||||
|
"0.0.0.0".to_string(), // hostname: String,
|
||||||
|
3000, // port: u16,
|
||||||
|
"/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String,
|
||||||
|
"llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String,
|
||||||
|
None, // tokenizer_config_path: Option<String>,
|
||||||
|
None, // revision: Option<String>,
|
||||||
|
2, // validation_workers: usize,
|
||||||
|
false, // json_output: bool,
|
||||||
|
None, // otlp_endpoint: Option<String>,
|
||||||
|
None, // cors_allow_origin: Option<Vec<String>>,
|
||||||
|
false, // ngrok: bool,
|
||||||
|
None, // ngrok_authtoken: Option<String>,
|
||||||
|
None, // ngrok_edge: Option<String>,
|
||||||
|
false, // messages_api_enabled: bool,
|
||||||
|
false, // disable_grammar_support: bool,
|
||||||
|
4, // max_client_batch_size: usize,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
Ok(Python::with_gil(|py| py.None()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn tgi(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(rust_server, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(rust_launcher, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(fully_packaged, m)?)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
132
tgi/tgi/__init__.py
Normal file
132
tgi/tgi/__init__.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from .tgi import *
|
||||||
|
import threading
|
||||||
|
from tgi import rust_launcher, rust_sleep, fully_packaged
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# add the rust_launcher coroutine to the __all__ list
|
||||||
|
__doc__ = tgi.__doc__
|
||||||
|
if hasattr(tgi, "__all__"):
|
||||||
|
__all__ = tgi.__all__
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Args:
|
||||||
|
model_id = "google/gemma-2b-it"
|
||||||
|
revision = None
|
||||||
|
validation_workers = 2
|
||||||
|
sharded = None
|
||||||
|
num_shard = None
|
||||||
|
quantize = None
|
||||||
|
speculate = None
|
||||||
|
dtype = None
|
||||||
|
trust_remote_code = True
|
||||||
|
max_concurrent_requests = 128
|
||||||
|
max_best_of = 2
|
||||||
|
max_stop_sequences = 4
|
||||||
|
max_top_n_tokens = 5
|
||||||
|
max_input_tokens = None
|
||||||
|
max_input_length = None
|
||||||
|
max_total_tokens = None
|
||||||
|
waiting_served_ratio = 0.3
|
||||||
|
max_batch_prefill_tokens = None
|
||||||
|
max_batch_total_tokens = None
|
||||||
|
max_waiting_tokens = 20
|
||||||
|
max_batch_size = None
|
||||||
|
cuda_graphs = None
|
||||||
|
hostname = "0.0.0.0"
|
||||||
|
port = 3000
|
||||||
|
shard_uds_path = "/tmp/text-generation-server"
|
||||||
|
master_addr = "localhost"
|
||||||
|
master_port = 29500
|
||||||
|
huggingface_hub_cache = None
|
||||||
|
weights_cache_override = None
|
||||||
|
disable_custom_kernels = False
|
||||||
|
cuda_memory_fraction = 1.0
|
||||||
|
rope_scaling = None
|
||||||
|
rope_factor = None
|
||||||
|
json_output = False
|
||||||
|
otlp_endpoint = None
|
||||||
|
cors_allow_origin = []
|
||||||
|
watermark_gamma = None
|
||||||
|
watermark_delta = None
|
||||||
|
ngrok = False
|
||||||
|
ngrok_authtoken = None
|
||||||
|
ngrok_edge = None
|
||||||
|
tokenizer_config_path = None
|
||||||
|
disable_grammar_support = False
|
||||||
|
env = False
|
||||||
|
max_client_batch_size = 4
|
||||||
|
|
||||||
|
|
||||||
|
class TGI(object):
|
||||||
|
# only allow a limited set of arguments for now
|
||||||
|
def __init__(self, model_id=None):
|
||||||
|
app_args = Args()
|
||||||
|
if model_id:
|
||||||
|
app_args.model_id = model_id
|
||||||
|
|
||||||
|
print(asdict(app_args))
|
||||||
|
self.thread = threading.Thread(target=self.run, args=(asdict(app_args),))
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
async def runit(self, args: dict):
|
||||||
|
print(args)
|
||||||
|
args = Args(**args)
|
||||||
|
try:
|
||||||
|
await fully_packaged(
|
||||||
|
args.model_id,
|
||||||
|
args.revision,
|
||||||
|
args.validation_workers,
|
||||||
|
args.sharded,
|
||||||
|
args.num_shard,
|
||||||
|
args.quantize,
|
||||||
|
args.speculate,
|
||||||
|
args.dtype,
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.max_concurrent_requests,
|
||||||
|
args.max_best_of,
|
||||||
|
args.max_stop_sequences,
|
||||||
|
args.max_top_n_tokens,
|
||||||
|
args.max_input_tokens,
|
||||||
|
args.max_input_length,
|
||||||
|
args.max_total_tokens,
|
||||||
|
args.waiting_served_ratio,
|
||||||
|
args.max_batch_prefill_tokens,
|
||||||
|
args.max_batch_total_tokens,
|
||||||
|
args.max_waiting_tokens,
|
||||||
|
args.max_batch_size,
|
||||||
|
args.cuda_graphs,
|
||||||
|
args.hostname,
|
||||||
|
args.port,
|
||||||
|
args.shard_uds_path,
|
||||||
|
args.master_addr,
|
||||||
|
args.master_port,
|
||||||
|
args.huggingface_hub_cache,
|
||||||
|
args.weights_cache_override,
|
||||||
|
args.disable_custom_kernels,
|
||||||
|
args.cuda_memory_fraction,
|
||||||
|
args.rope_scaling,
|
||||||
|
args.rope_factor,
|
||||||
|
args.json_output,
|
||||||
|
args.otlp_endpoint,
|
||||||
|
args.cors_allow_origin,
|
||||||
|
args.watermark_gamma,
|
||||||
|
args.watermark_delta,
|
||||||
|
args.ngrok,
|
||||||
|
args.ngrok_authtoken,
|
||||||
|
args.ngrok_edge,
|
||||||
|
args.tokenizer_config_path,
|
||||||
|
args.disable_grammar_support,
|
||||||
|
args.env,
|
||||||
|
args.max_client_batch_size,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def run(self, args: dict):
|
||||||
|
asyncio.run(self.runit(args))
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.thread.join()
|
Loading…
Reference in New Issue
Block a user