From 0e5220d704fc6468437370329eaa0102844658ae Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 16 May 2024 19:47:12 +0000
Subject: [PATCH] feat: experimental python packaging and interface

---
 Cargo.lock           |  102 +++
 Cargo.toml           |    3 +-
 launcher/src/lib.rs  | 2026 ++++++++++++++++++++++++++++++++++++++++++
 launcher/src/main.rs | 1627 +--------------------------------
 router/src/lib.rs    |  476 ++++++++++
 router/src/main.rs   |  530 +----------
 tgi/.gitignore       |   72 ++
 tgi/Cargo.toml       |   16 +
 tgi/Makefile         |    6 +
 tgi/README.md        |   47 +
 tgi/app.py           |   38 +
 tgi/pyproject.toml   |   15 +
 tgi/src/lib.rs       |  455 ++++++++++
 tgi/tgi/__init__.py  |  132 +++
 14 files changed, 3465 insertions(+), 2080 deletions(-)
 create mode 100644 launcher/src/lib.rs
 create mode 100644 tgi/.gitignore
 create mode 100644 tgi/Cargo.toml
 create mode 100644 tgi/Makefile
 create mode 100644 tgi/README.md
 create mode 100644 tgi/app.py
 create mode 100644 tgi/pyproject.toml
 create mode 100644 tgi/src/lib.rs
 create mode 100644 tgi/tgi/__init__.py

diff --git a/Cargo.lock b/Cargo.lock
index d58f4cb1..248526e7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1783,6 +1783,15 @@ version = "2.7.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
 
+[[package]]
+name = "memoffset"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "metrics"
 version = "0.21.1"
@@ -2665,6 +2674,82 @@ dependencies = [
  "prost 0.12.6",
 ]
 
+[[package]]
+name = "pyo3"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
+dependencies = [
+ "cfg-if",
+ "indoc",
+ "libc",
+ "memoffset",
+ "parking_lot",
+ "portable-atomic",
+ "pyo3-build-config",
+ "pyo3-ffi",
+ "pyo3-macros",
+ "unindent",
+]
+
+[[package]]
+name = "pyo3-asyncio"
+version = "0.20.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ea6b68e93db3622f3bb3bf363246cf948ed5375afe7abff98ccbdd50b184995"
+dependencies = [
+ "futures",
+ "once_cell",
+ "pin-project-lite",
+ "pyo3",
+ "tokio",
+]
+
+[[package]]
+name = "pyo3-build-config"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
+dependencies = [
+ "once_cell",
+ "target-lexicon",
+]
+
+[[package]]
+name = "pyo3-ffi"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
+dependencies = [
+ "libc",
+ "pyo3-build-config",
+]
+
+[[package]]
+name = "pyo3-macros"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
+dependencies = [
+ "proc-macro2",
+ "pyo3-macros-backend",
+ "quote",
+ "syn 2.0.60",
+]
+
+[[package]]
+name = "pyo3-macros-backend"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
+dependencies = [
+ "heck 0.4.1",
+ "proc-macro2",
+ "pyo3-build-config",
+ "quote",
+ "syn 2.0.60",
+]
+
 [[package]]
 name = "qoi"
 version = "0.4.1"
@@ -3627,6 +3712,17 @@ dependencies = [
  "vergen",
 ]
 
+[[package]]
+name = "tgi"
+version = "0.1.0"
+dependencies = [
+ "pyo3",
+ "pyo3-asyncio",
+ "text-generation-launcher",
+ "text-generation-router",
+ "tokio",
+]
+
 [[package]]
 name = "thiserror"
 version = "1.0.61"
@@ -4186,6 +4282,12 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
 
+[[package]]
+name = "unindent"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
+
 [[package]]
 name = "untrusted"
 version = "0.7.1"
diff --git a/Cargo.toml b/Cargo.toml
index c5c6ca6e..3166149f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,8 @@ members = [
     "router",
     "router/client",
     "router/grpc-metadata",
-    "launcher"
+    "launcher",
+    "tgi"
 ]
 resolver = "2"
 
diff --git a/launcher/src/lib.rs b/launcher/src/lib.rs
new file mode 100644
index 00000000..d917a504
--- /dev/null
+++ b/launcher/src/lib.rs
@@ -0,0 +1,2026 @@
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use nix::sys::signal::{self, Signal};
+use nix::unistd::Pid;
+use serde::Deserialize;
+use std::env;
+use std::ffi::OsString;
+use std::io::{BufRead, BufReader, Lines};
+use std::os::unix::process::{CommandExt, ExitStatusExt};
+use std::path::Path;
+use std::process::{Child, Command, ExitStatus, Stdio};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::mpsc::TryRecvError;
+use std::sync::{mpsc, Arc};
+use std::thread;
+use std::thread::sleep;
+use std::time::{Duration, Instant};
+use std::{fs, io};
+use thiserror::Error;
+use tracing_subscriber::EnvFilter;
+
+mod env_runtime;
+
+#[derive(Deserialize)]
+struct Config {
+    max_position_embeddings: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+pub enum Quantization {
+    /// 4 bit quantization. Requires a specific AWQ quantized model.
+    /// Should replace GPTQ models wherever possible because of the better latency.
+    Awq,
+    /// 8 bit quantization, doesn't require a specific model.
+    /// Should be a drop-in replacement to bitsandbytes with much better performance.
+    /// Kernels are from the EETQ project.
+    Eetq,
+    /// 4 bit quantization. Requires a specific GPTQ quantized model.
+    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
+    /// the triton kernel (wider support) when it's not.
+    /// AWQ has faster kernels.
+    Gptq,
+    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
+    /// but it is known that the model will be much slower to run than the native f16.
+    #[deprecated(
+        since = "1.1.0",
+        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    )]
+    Bitsandbytes,
+    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
+    /// but it is known that the model will be much slower to run than the native f16.
+    BitsandbytesNF4,
+    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
+    /// perplexity performance for your model.
+    BitsandbytesFP4,
+    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above.
+    /// This dtype has native ops and should be the fastest if available.
+    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
+    /// multiplication limitations.
+ Fp8, +} + +impl std::fmt::Display for Quantization { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + #[allow(deprecated)] + // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases + Quantization::Bitsandbytes => { + write!(f, "bitsandbytes") + } + Quantization::BitsandbytesNF4 => { + write!(f, "bitsandbytes-nf4") + } + Quantization::BitsandbytesFP4 => { + write!(f, "bitsandbytes-fp4") + } + Quantization::Gptq => { + write!(f, "gptq") + } + Quantization::Awq => { + write!(f, "awq") + } + Quantization::Eetq => { + write!(f, "eetq") + } + Quantization::Fp8 => { + write!(f, "fp8") + } + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum Dtype { + Float16, + #[clap(name = "bfloat16")] + BFloat16, +} + +impl std::fmt::Display for Dtype { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + Dtype::Float16 => { + write!(f, "float16") + } + Dtype::BFloat16 => { + write!(f, "bfloat16") + } + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum RopeScaling { + Linear, + Dynamic, +} + +impl std::fmt::Display for RopeScaling { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + RopeScaling::Linear => { + write!(f, "linear") + } + RopeScaling::Dynamic => { + write!(f, "dynamic") + } + } + } +} + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +pub struct Args { + /// The name of the model to load. + /// Can be a MODEL_ID as listed on like + /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. + /// Or it can be a local directory containing the necessary files + /// as saved by `save_pretrained(...)` methods of transformers + #[clap(default_value = "bigscience/bloom-560m", long, env)] + pub model_id: String, + + /// The actual revision of the model if you're referring to a model + /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. + #[clap(long, env)] + pub revision: Option, + + /// The number of tokenizer workers used for payload validation and truncation inside the + /// router. + #[clap(default_value = "2", long, env)] + pub validation_workers: usize, + + /// Whether to shard the model across multiple GPUs + /// By default text-generation-inference will use all available GPUs to run + /// the model. Setting it to `false` deactivates `num_shard`. + #[clap(long, env)] + pub sharded: Option, + + /// The number of shards to use if you don't want to use all GPUs on a given machine. + /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` + /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to + /// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance. + #[clap(long, env)] + pub num_shard: Option, + + /// Whether you want the model to be quantized. + #[clap(long, env, value_enum)] + pub quantize: Option, + + /// The number of input_ids to speculate on + /// If using a medusa model, the heads will be picked up automatically + /// Other wise, it will use n-gram speculation which is relatively free + /// in terms of compute, but the speedup heavily depends on the task. + #[clap(long, env)] + pub speculate: Option, + + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. 
+ #[clap(long, env, value_enum)] + pub dtype: Option, + + /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is + /// encouraged when loading a model with custom code to ensure no malicious code has been + /// contributed in a newer revision. + #[clap(long, env, value_enum)] + pub trust_remote_code: bool, + + /// The maximum amount of concurrent requests for this particular deployment. + /// Having a low limit will refuse clients requests instead of having them + /// wait for too long and is usually good to handle backpressure correctly. + #[clap(default_value = "128", long, env)] + pub max_concurrent_requests: usize, + + /// This is the maximum allowed value for clients to set `best_of`. + /// Best of makes `n` generations at the same time, and return the best + /// in terms of overall log probability over the entire generated sequence + #[clap(default_value = "2", long, env)] + pub max_best_of: usize, + + /// This is the maximum allowed value for clients to set `stop_sequences`. + /// Stop sequences are used to allow the model to stop on more than just + /// the EOS token, and enable more complex "prompting" where users can preprompt + /// the model in a specific way and define their "own" stop token aligned with + /// their prompt. + #[clap(default_value = "4", long, env)] + pub max_stop_sequences: usize, + + /// This is the maximum allowed value for clients to set `top_n_tokens`. + /// `top_n_tokens is used to return information about the the `n` most likely + /// tokens at each generation step, instead of just the sampled token. This + /// information can be used for downstream tasks like for classification or + /// ranking. + #[clap(default_value = "5", long, env)] + pub max_top_n_tokens: u32, + + /// This is the maximum allowed input length (expressed in number of tokens) + /// for users. The larger this value, the longer prompt users can send which + /// can impact the overall memory required to handle the load. + /// Please note that some models have a finite range of sequence they can handle. + /// Default to min(max_position_embeddings - 1, 4095) + #[clap(long, env)] + pub max_input_tokens: Option, + + /// Legacy version of [`Args::max_input_tokens`]. + #[clap(long, env)] + pub max_input_length: Option, + + /// This is the most important value to set as it defines the "memory budget" + /// of running clients requests. + /// Clients will send input sequences and ask to generate `max_new_tokens` + /// on top. with a value of `1512` users can send either a prompt of + /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for + /// `1511` max_new_tokens. + /// The larger this value, the larger amount each request will be in your RAM + /// and the less effective batching can be. + /// Default to min(max_position_embeddings, 4096) + #[clap(long, env)] + pub max_total_tokens: Option, + + /// This represents the ratio of waiting queries vs running queries where + /// you want to start considering pausing the running queries to include the waiting + /// ones into the same batch. + /// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's + /// only 10 queries left in the current batch we check if we can fit those 12 + /// waiting queries into the batching strategy, and if yes, then batching happens + /// delaying the 10 running queries by a `prefill` run. + /// + /// This setting is only applied if there is room in the batch + /// as defined by `max_batch_total_tokens`. 
+ #[clap(default_value = "0.3", long, env)] + pub waiting_served_ratio: f32, + + /// Limits the number of tokens for the prefill operation. + /// Since this operation take the most memory and is compute bound, it is interesting + /// to limit the number of requests that can be sent. + /// Default to `max_input_tokens + 50` to give a bit of room. + #[clap(long, env)] + pub max_batch_prefill_tokens: Option, + + /// **IMPORTANT** This is one critical control to allow maximum usage + /// of the available hardware. + /// + /// This represents the total amount of potential tokens within a batch. + /// When using padding (not recommended) this would be equivalent of + /// `batch_size` * `max_total_tokens`. + /// + /// However in the non-padded (flash attention) version this can be much finer. + /// + /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` + /// or a single query of `1000` tokens. + /// + /// Overall this number should be the largest possible amount that fits the + /// remaining memory (after the model is loaded). Since the actual memory overhead + /// depends on other parameters like if you're using quantization, flash attention + /// or the model implementation, text-generation-inference cannot infer this number + /// automatically. + #[clap(long, env)] + pub max_batch_total_tokens: Option, + + /// This setting defines how many tokens can be passed before forcing the waiting + /// queries to be put on the batch (if the size of the batch allows for it). + /// New queries require 1 `prefill` forward, which is different from `decode` + /// and therefore you need to pause the running batch in order to run `prefill` + /// to create the correct values for the waiting queries to be able to join the batch. + /// + /// With a value too small, queries will always "steal" the compute to run `prefill` + /// and running queries will be delayed by a lot. + /// + /// With a value too big, waiting queries could wait for a very long time + /// before being allowed a slot in the running batch. If your server is busy + /// that means that requests that could run in ~2s on an empty server could + /// end up running in ~20s because the query had to wait for 18s. + /// + /// This number is expressed in number of tokens to make it a bit more + /// "model" agnostic, but what should really matter is the overall latency + /// for end users. + #[clap(default_value = "20", long, env)] + pub max_waiting_tokens: usize, + + /// Enforce a maximum number of requests per batch + /// Specific flag for hardware targets that do not support unpadded inference + #[clap(long, env)] + pub max_batch_size: Option, + + /// Specify the batch sizes to compute cuda graphs for. + /// Use "0" to disable. + /// Default = "1,2,4,8,16,32" + #[clap(long, env, value_delimiter = ',')] + pub cuda_graphs: Option>, + + /// The IP address to listen on + #[clap(default_value = "0.0.0.0", long, env)] + pub hostname: String, + + /// The port to listen on. + #[clap(default_value = "3000", long, short, env)] + pub port: u16, + + /// The name of the socket for gRPC communication between the webserver + /// and the shards. + #[clap(default_value = "/tmp/text-generation-server", long, env)] + pub shard_uds_path: String, + + /// The address the master shard will listen on. (setting used by torch distributed) + #[clap(default_value = "localhost", long, env)] + pub master_addr: String, + + /// The address the master port will listen on. 
(setting used by torch distributed) + #[clap(default_value = "29500", long, env)] + pub master_port: usize, + + /// The location of the huggingface hub cache. + /// Used to override the location if you want to provide a mounted disk for instance + #[clap(long, env)] + pub huggingface_hub_cache: Option, + + /// The location of the huggingface hub cache. + /// Used to override the location if you want to provide a mounted disk for instance + #[clap(long, env)] + pub weights_cache_override: Option, + + /// For some models (like bloom), text-generation-inference implemented custom + /// cuda kernels to speed up inference. Those kernels were only tested on A100. + /// Use this flag to disable them if you're running on different hardware and + /// encounter issues. + #[clap(long, env)] + pub disable_custom_kernels: bool, + + /// Limit the CUDA available memory. + /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction. + #[clap(default_value = "1.0", long, env)] + pub cuda_memory_fraction: f32, + + /// Rope scaling will only be used for RoPE models + /// and allow rescaling the position rotary to accomodate for + /// larger prompts. + /// + /// Goes together with `rope_factor`. + /// + /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0 + /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 + /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed + /// basically) + /// + /// `--rope-scaling linear --rope-factor` fully describes the scaling you want + #[clap(long, env)] + pub rope_scaling: Option, + + /// Rope scaling will only be used for RoPE models + /// See `rope_scaling` + #[clap(long, env)] + pub rope_factor: Option, + + /// Outputs the logs in JSON format (useful for telemetry) + #[clap(long, env)] + pub json_output: bool, + + #[clap(long, env)] + pub otlp_endpoint: Option, + + #[clap(long, env)] + pub cors_allow_origin: Vec, + #[clap(long, env)] + pub watermark_gamma: Option, + #[clap(long, env)] + pub watermark_delta: Option, + + /// Enable ngrok tunneling + #[clap(long, env)] + pub ngrok: bool, + + /// ngrok authentication token + #[clap(long, env)] + pub ngrok_authtoken: Option, + + /// ngrok edge + #[clap(long, env)] + pub ngrok_edge: Option, + + /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may + /// include a `chat_template`. If not provided, the default config will be used from the model hub. + #[clap(long, env)] + pub tokenizer_config_path: Option, + + /// Disable outlines grammar constrained generation. + /// This is a feature that allows you to generate text that follows a specific grammar. 
+ #[clap(long, env)] + pub disable_grammar_support: bool, + + /// Display a lot of information about your runtime environment + #[clap(long, short, action)] + pub env: bool, + + /// Control the maximum number of inputs that a client can send in a single request + #[clap(default_value = "4", long, env)] + pub max_client_batch_size: usize, +} + +#[derive(Debug)] +enum ShardStatus { + Ready, + Failed(usize), +} + +#[allow(clippy::too_many_arguments)] +fn shard_manager( + model_id: String, + revision: Option, + quantize: Option, + speculate: Option, + dtype: Option, + trust_remote_code: bool, + uds_path: String, + rank: usize, + world_size: usize, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + watermark_gamma: Option, + watermark_delta: Option, + cuda_graphs: Vec, + cuda_memory_fraction: f32, + rope_scaling: Option, + rope_factor: Option, + max_total_tokens: usize, + max_batch_size: Option, + otlp_endpoint: Option, + status_sender: mpsc::Sender, + shutdown: Arc, + _shutdown_sender: mpsc::Sender<()>, +) { + // Enter shard-manager tracing span + let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); + + // Get UDS path + let uds_string = format!("{uds_path}-{rank}"); + let uds = Path::new(&uds_string); + // Clean previous runs + if uds.exists() { + fs::remove_file(uds).unwrap(); + } + + // Process args + let mut shard_args = vec![ + "serve".to_string(), + model_id, + "--uds-path".to_string(), + uds_path, + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + + // Activate trust remote code + if trust_remote_code { + shard_args.push("--trust-remote-code".to_string()); + } + + // Activate tensor parallelism + if world_size > 1 { + shard_args.push("--sharded".to_string()); + } + + if let Some(quantize) = quantize { + shard_args.push("--quantize".to_string()); + shard_args.push(quantize.to_string()) + } + + if let Some(speculate) = speculate { + shard_args.push("--speculate".to_string()); + shard_args.push(speculate.to_string()) + } + + if let Some(dtype) = dtype { + shard_args.push("--dtype".to_string()); + shard_args.push(dtype.to_string()) + } + + // Model optional revision + if let Some(revision) = revision { + shard_args.push("--revision".to_string()); + shard_args.push(revision) + } + + let rope = match (rope_scaling, rope_factor) { + (None, None) => None, + (Some(scaling), None) => Some((scaling, 1.0)), + (Some(scaling), Some(factor)) => Some((scaling, factor)), + (None, Some(factor)) => Some((RopeScaling::Linear, factor)), + }; + + // OpenTelemetry + if let Some(otlp_endpoint) = otlp_endpoint { + shard_args.push("--otlp-endpoint".to_string()); + shard_args.push(otlp_endpoint); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Remove LOG_LEVEL if present + envs.retain(|(name, _)| name != "LOG_LEVEL"); + + // Torch Distributed Env vars + envs.push(("RANK".into(), rank.to_string().into())); + envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); + envs.push(("MASTER_ADDR".into(), master_addr.into())); + envs.push(("MASTER_PORT".into(), master_port.to_string().into())); + envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into())); + + // CUDA memory fraction + envs.push(( + "CUDA_MEMORY_FRACTION".into(), + cuda_memory_fraction.to_string().into(), + )); + + // Safetensors load fast + envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); + + // 
Disable progress bar + envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + envs.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // Detect rope scaling + // Sending as env instead of CLI args to not bloat everything + // those only can be used by RoPE models, so passing information around + // for all models will complexify code unnecessarily + if let Some((scaling, factor)) = rope { + envs.push(("ROPE_SCALING".into(), scaling.to_string().into())); + envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); + } + + envs.push(( + "MAX_TOTAL_TOKENS".into(), + max_total_tokens.to_string().into(), + )); + if let Some(max_batch_size) = max_batch_size { + envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into())); + } + + // If huggingface_hub_cache is some, pass it to the shard + // Useful when running inside a docker container + if let Some(huggingface_hub_cache) = huggingface_hub_cache { + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // If weights_cache_override is some, pass it to the shard + // Useful when running inside a HuggingFace Inference Endpoint + if let Some(weights_cache_override) = weights_cache_override { + envs.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); + }; + + // Enable experimental support for cuda graphs + if !cuda_graphs.is_empty() { + envs.push(( + "CUDA_GRAPHS".into(), + cuda_graphs + .into_iter() + .map(|c| c.to_string()) + .collect::>() + .join(",") + .into(), + )); + } + + // If disable_custom_kernels is true, pass it to the shard as an env var + if disable_custom_kernels { + envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) + } + + // Watermark Gamma + if let Some(watermark_gamma) = watermark_gamma { + envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) + } + + // Watermark Delta + if let Some(watermark_delta) = watermark_delta { + envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) + } + + // Start process + tracing::info!("Starting shard"); + let mut p = match Command::new("text-generation-server") + .args(shard_args) + .env_clear() + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } + { + tracing::error!("{}", err); + } + + status_sender.send(ShardStatus::Failed(rank)).unwrap(); + return; + } + }; + + // Redirect STDOUT to the console + let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); + let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + + //stdout tracing thread + thread::spawn(move || { + log_lines(shard_stdout_reader.lines()); + }); + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in shard_stderr_reader.lines().map_while(Result::ok) { + err_sender.send(line).unwrap_or(()); + } + }); + + let mut ready = false; + let start_time = Instant::now(); + let mut 
wait_time = Instant::now(); + loop { + // Process exited + if let Some(exit_status) = p.try_wait().unwrap() { + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + + tracing::error!("Shard complete standard error output:\n{err}"); + + if let Some(signal) = exit_status.signal() { + tracing::error!("Shard process was signaled to shutdown with signal {signal}"); + } + + status_sender.send(ShardStatus::Failed(rank)).unwrap(); + return; + } + + // We received a shutdown signal + if shutdown.load(Ordering::SeqCst) { + terminate("shard", p, Duration::from_secs(90)).unwrap(); + return; + } + + // Shard is ready + if uds.exists() && !ready { + tracing::info!("Shard ready in {:?}", start_time.elapsed()); + status_sender.send(ShardStatus::Ready).unwrap(); + ready = true; + } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { + tracing::info!("Waiting for shard to be ready..."); + wait_time = Instant::now(); + } + sleep(Duration::from_millis(100)); + } +} + +fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { + tracing::info!("Shutting down shards"); + // Update shutdown value to true + // This will be picked up by the shard manager + shutdown.store(true, Ordering::SeqCst); + + // Wait for shards to shutdown + // This will block till all shutdown_sender are dropped + let _ = shutdown_receiver.recv(); +} + +fn num_cuda_devices() -> Option { + let devices = match env::var("CUDA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + }; + let n_devices = devices.split(',').count(); + Some(n_devices) +} + +#[derive(Deserialize)] +#[serde(rename_all = "UPPERCASE")] +enum PythonLogLevelEnum { + Trace, + Debug, + Info, + Success, + Warning, + Error, + Critical, +} + +#[derive(Deserialize)] +struct PythonLogLevel { + name: PythonLogLevelEnum, +} + +#[derive(Deserialize)] +struct PythonLogRecord { + level: PythonLogLevel, +} + +#[derive(Deserialize)] +struct PythonLogMessage { + text: String, + record: PythonLogRecord, +} + +impl PythonLogMessage { + fn trace(&self) { + match self.record.level.name { + PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), + PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), + PythonLogLevelEnum::Info => tracing::info!("{}", self.text), + PythonLogLevelEnum::Success => tracing::info!("{}", self.text), + PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), + PythonLogLevelEnum::Error => tracing::error!("{}", self.text), + PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), + } + } +} + +impl TryFrom<&String> for PythonLogMessage { + type Error = serde_json::Error; + + fn try_from(value: &String) -> Result { + serde_json::from_str::(value) + } +} + +fn log_lines(lines: Lines) { + for line in lines.map_while(Result::ok) { + match PythonLogMessage::try_from(&line) { + Ok(log) => log.trace(), + Err(_) => tracing::debug!("{line}"), + } + } +} + +fn find_num_shards( + sharded: Option, + num_shard: Option, +) -> Result { + // get the number of shards given `sharded` and `num_shard` + let num_shard = match (sharded, num_shard) { + (Some(true), None) => { + // try to default to the number of available GPUs + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + let n_devices = num_cuda_devices() + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); + if n_devices <= 1 { + return 
Err(LauncherError::NotEnoughCUDADevices(format!( + "`sharded` is true but only found {n_devices} CUDA devices" + ))); + } + n_devices + } + (Some(true), Some(num_shard)) => { + // we can't have only one shard while sharded + if num_shard <= 1 { + return Err(LauncherError::ArgumentValidation( + "`sharded` is true but `num_shard` <= 1".to_string(), + )); + } + num_shard + } + (Some(false), Some(num_shard)) => num_shard, + (Some(false), None) => 1, + (None, None) => num_cuda_devices().unwrap_or(1), + (None, Some(num_shard)) => num_shard, + }; + if num_shard < 1 { + return Err(LauncherError::ArgumentValidation( + "`num_shard` cannot be < 1".to_string(), + )); + } + Ok(num_shard) +} + +#[derive(Debug, Error)] +pub enum LauncherError { + #[error("Invalid argument: {0}")] + ArgumentValidation(String), + #[error("not enough cuda devices: {0}")] + NotEnoughCUDADevices(String), + #[error("Download error")] + DownloadError, + #[error("Shard cannot start")] + ShardCannotStart, + #[error("Shard disconnected")] + ShardDisconnected, + #[error("Shard failed")] + ShardFailed, + #[error("Webserver failed")] + WebserverFailed, + #[error("Webserver cannot start")] + WebserverCannotStart, +} + +fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { + // Enter download tracing span + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + + let mut download_args = vec![ + "download-weights".to_string(), + args.model_id.to_string(), + "--extension".to_string(), + ".safetensors".to_string(), + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + + // Model optional revision + if let Some(revision) = &args.revision { + download_args.push("--revision".to_string()); + download_args.push(revision.to_string()) + } + + // Trust remote code for automatic peft fusion + if args.trust_remote_code { + download_args.push("--trust-remote-code".to_string()); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Remove LOG_LEVEL if present + envs.retain(|(name, _)| name != "LOG_LEVEL"); + + // Disable progress bar + envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); + + // If huggingface_hub_cache is set, pass it to the download process + // Useful when running inside a docker container + if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + envs.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // If args.weights_cache_override is some, pass it to the download process + // Useful when running inside a HuggingFace Inference Endpoint + if let Some(weights_cache_override) = &args.weights_cache_override { + envs.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); + }; + + // Start process + tracing::info!("Starting download process."); + let mut download_process = match Command::new("text-generation-server") + .args(download_args) + .env_clear() + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + if err.kind() == io::ErrorKind::NotFound { + 
tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } else { + tracing::error!("{}", err); + } + + return Err(LauncherError::DownloadError); + } + }; + + let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); + + thread::spawn(move || { + log_lines(download_stdout.lines()); + }); + + let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); + + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in download_stderr.lines().map_while(Result::ok) { + err_sender.send(line).unwrap_or(()); + } + }); + + loop { + if let Some(status) = download_process.try_wait().unwrap() { + if status.success() { + tracing::info!("Successfully downloaded weights."); + break; + } + + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + + if let Some(signal) = status.signal() { + tracing::error!( + "Download process was signaled to shutdown with signal {signal}: {err}" + ); + } else { + tracing::error!("Download encountered an error: {err}"); + } + + return Err(LauncherError::DownloadError); + } + if !running.load(Ordering::SeqCst) { + terminate("download", download_process, Duration::from_secs(10)).unwrap(); + return Ok(()); + } + sleep(Duration::from_millis(100)); + } + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn spawn_shards( + num_shard: usize, + args: &Args, + cuda_graphs: Vec, + max_total_tokens: usize, + shutdown: Arc, + shutdown_receiver: &mpsc::Receiver<()>, + shutdown_sender: mpsc::Sender<()>, + status_receiver: &mpsc::Receiver, + status_sender: mpsc::Sender, + running: Arc, +) -> Result<(), LauncherError> { + // Start shard processes + for rank in 0..num_shard { + let model_id = args.model_id.clone(); + let revision = args.revision.clone(); + let uds_path = args.shard_uds_path.clone(); + let master_addr = args.master_addr.clone(); + let huggingface_hub_cache = args.huggingface_hub_cache.clone(); + let weights_cache_override = args.weights_cache_override.clone(); + let status_sender = status_sender.clone(); + let shutdown = shutdown.clone(); + let shutdown_sender = shutdown_sender.clone(); + let otlp_endpoint = args.otlp_endpoint.clone(); + let quantize = args.quantize; + let speculate = args.speculate; + let dtype = args.dtype; + let trust_remote_code = args.trust_remote_code; + let master_port = args.master_port; + let disable_custom_kernels = args.disable_custom_kernels; + let watermark_gamma = args.watermark_gamma; + let watermark_delta = args.watermark_delta; + let cuda_graphs_clone = cuda_graphs.clone(); + let cuda_memory_fraction = args.cuda_memory_fraction; + let rope_scaling = args.rope_scaling; + let rope_factor = args.rope_factor; + let max_batch_size = args.max_batch_size; + thread::spawn(move || { + shard_manager( + model_id, + revision, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + rank, + num_shard, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + watermark_gamma, + watermark_delta, + cuda_graphs_clone, + cuda_memory_fraction, + rope_scaling, + rope_factor, + max_total_tokens, + max_batch_size, + otlp_endpoint, + status_sender, + shutdown, + shutdown_sender, + ) + }); + } + drop(shutdown_sender); + + // Wait for shard to start + let mut shard_ready = 0; + while 
running.load(Ordering::SeqCst) { + match status_receiver.try_recv() { + Ok(ShardStatus::Ready) => { + shard_ready += 1; + if shard_ready == num_shard { + break; + } + } + Err(TryRecvError::Empty) => { + sleep(Duration::from_millis(100)); + } + Ok(ShardStatus::Failed(rank)) => { + tracing::error!("Shard {rank} failed to start"); + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::ShardCannotStart); + } + Err(TryRecvError::Disconnected) => { + tracing::error!("Shard status channel disconnected"); + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::ShardDisconnected); + } + } + } + Ok(()) +} + +fn compute_type(num_shard: usize) -> Option { + let output = Command::new("nvidia-smi") + .args(["--query-gpu=gpu_name", "--format=csv"]) + .output() + .ok()?; + let output = String::from_utf8(output.stdout).ok()?; + let fullname = output.split('\n').nth(1)?; + let cardname = fullname.replace(' ', "-").to_lowercase(); + let compute_type = format!("{num_shard}-{cardname}"); + Some(compute_type) +} + +fn spawn_webserver( + num_shard: usize, + args: Args, + max_input_tokens: usize, + max_total_tokens: usize, + max_batch_prefill_tokens: u32, + shutdown: Arc, + shutdown_receiver: &mpsc::Receiver<()>, +) -> Result { + // All shard started + // Start webserver + tracing::info!("Starting Webserver"); + let mut router_args = vec![ + "--max-client-batch-size".to_string(), + args.max_client_batch_size.to_string(), + "--max-concurrent-requests".to_string(), + args.max_concurrent_requests.to_string(), + "--max-best-of".to_string(), + args.max_best_of.to_string(), + "--max-stop-sequences".to_string(), + args.max_stop_sequences.to_string(), + "--max-top-n-tokens".to_string(), + args.max_top_n_tokens.to_string(), + "--max-input-tokens".to_string(), + max_input_tokens.to_string(), + "--max-total-tokens".to_string(), + max_total_tokens.to_string(), + "--max-batch-prefill-tokens".to_string(), + max_batch_prefill_tokens.to_string(), + "--waiting-served-ratio".to_string(), + args.waiting_served_ratio.to_string(), + "--max-waiting-tokens".to_string(), + args.max_waiting_tokens.to_string(), + "--validation-workers".to_string(), + args.validation_workers.to_string(), + "--hostname".to_string(), + args.hostname.to_string(), + "--port".to_string(), + args.port.to_string(), + "--master-shard-uds-path".to_string(), + format!("{}-0", args.shard_uds_path), + "--tokenizer-name".to_string(), + args.model_id, + ]; + + // Grammar support + if args.disable_grammar_support { + router_args.push("--disable-grammar-support".to_string()); + } + + // Tokenizer config path + if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { + router_args.push("--tokenizer-config-path".to_string()); + router_args.push(tokenizer_config_path.to_string()); + } + + // Model optional max batch total tokens + if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { + router_args.push("--max-batch-total-tokens".to_string()); + router_args.push(max_batch_total_tokens.to_string()); + } + + // Router optional max batch size + if let Some(max_batch_size) = args.max_batch_size { + router_args.push("--max-batch-size".to_string()); + router_args.push(max_batch_size.to_string()); + } + + // Model optional revision + if let Some(ref revision) = args.revision { + router_args.push("--revision".to_string()); + router_args.push(revision.to_string()) + } + + if args.json_output { + router_args.push("--json-output".to_string()); + } + + // OpenTelemetry + if let Some(otlp_endpoint) = args.otlp_endpoint { 
+ router_args.push("--otlp-endpoint".to_string()); + router_args.push(otlp_endpoint); + } + + // CORS origins + for origin in args.cors_allow_origin.into_iter() { + router_args.push("--cors-allow-origin".to_string()); + router_args.push(origin); + } + + // Ngrok + if args.ngrok { + router_args.push("--ngrok".to_string()); + router_args.push("--ngrok-authtoken".to_string()); + router_args.push(args.ngrok_authtoken.unwrap()); + router_args.push("--ngrok-edge".to_string()); + router_args.push(args.ngrok_edge.unwrap()); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // Parse Compute type + if let Ok(compute_type) = env::var("COMPUTE_TYPE") { + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + } else if let Some(compute_type) = compute_type(num_shard) { + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + } + + let mut webserver = match Command::new("text-generation-router") + .args(router_args) + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + tracing::error!("Failed to start webserver: {}", err); + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-router not found in PATH"); + tracing::error!("Please install it with `make install-router`") + } else { + tracing::error!("{}", err); + } + + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::WebserverCannotStart); + } + }; + + // Redirect STDOUT and STDERR to the console + let webserver_stdout = webserver.stdout.take().unwrap(); + let webserver_stderr = webserver.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(webserver_stdout); + let stderr = BufReader::new(webserver_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + Ok(webserver) +} + +fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result { + tracing::info!("Terminating {process_name}"); + + let terminate_time = Instant::now(); + signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); + + tracing::info!("Waiting for {process_name} to gracefully shutdown"); + while terminate_time.elapsed() < timeout { + if let Some(status) = process.try_wait()? 
{ + tracing::info!("{process_name} terminated"); + return Ok(status); + } + sleep(Duration::from_millis(100)); + } + tracing::info!("Killing {process_name}"); + + process.kill()?; + let exit_status = process.wait()?; + + tracing::info!("{process_name} killed"); + Ok(exit_status) +} + +#[allow(clippy::too_many_arguments)] +pub fn launcher_main( + model_id: String, + revision: Option, + validation_workers: usize, + sharded: Option, + num_shard: Option, + quantize: Option, + speculate: Option, + dtype: Option, + trust_remote_code: bool, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: Option, + max_input_length: Option, + max_total_tokens: Option, + waiting_served_ratio: f32, + max_batch_prefill_tokens: Option, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + cuda_graphs: Option>, + hostname: String, + port: u16, + shard_uds_path: String, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + cuda_memory_fraction: f32, + rope_scaling: Option, + rope_factor: Option, + json_output: bool, + otlp_endpoint: Option, + cors_allow_origin: Vec, + watermark_gamma: Option, + watermark_delta: Option, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_edge: Option, + tokenizer_config_path: Option, + disable_grammar_support: bool, + env: bool, + max_client_batch_size: usize, +) -> Result<(), LauncherError> { + let args = Args { + model_id, + revision, + validation_workers, + sharded, + num_shard, + quantize, + speculate, + dtype, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + rope_scaling, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, + }; + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + if args.json_output { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .json() + .init(); + } else { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .compact() + .init(); + } + + if args.env { + let env_runtime = env_runtime::Env::new(); + tracing::info!("{}", env_runtime); + } + + tracing::info!("{:#?}", args); + + let get_max_position_embeddings = || -> Result> { + let model_id = args.model_id.clone(); + let mut path = std::path::Path::new(&args.model_id).to_path_buf(); + let filename = if !path.exists() { + // Assume it's a hub id + let api = Api::new()?; + let repo = if let Some(ref revision) = args.revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: Config = serde_json::from_str(&content)?; + + // Quantization usually means you're even more RAM constrained. 
+ let max_default = 4096; + + let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { + (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + max_default + } else { + max_position_embeddings + } + } + _ => { + return Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))); + } + }; + Ok(max_position_embeddings) + }; + let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); + + let max_input_tokens = { + match (args.max_input_tokens, args.max_input_length) { + (Some(max_input_tokens), Some(max_input_length)) => { + return Err(LauncherError::ArgumentValidation( + format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.", + ))); + } + (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, + (None, None) => { + let value = max_position_embeddings - 1; + tracing::info!("Default `max_input_tokens` to {value}"); + value + } + } + }; + let max_total_tokens = { + match args.max_total_tokens { + Some(max_total_tokens) => max_total_tokens, + None => { + let value = max_position_embeddings; + tracing::info!("Default `max_total_tokens` to {value}"); + value + } + } + }; + let max_batch_prefill_tokens = { + match args.max_batch_prefill_tokens { + Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, + None => { + let value: u32 = if let Some(max_batch_size) = args.max_batch_size { + max_batch_size * max_input_tokens + } else { + // Adding some edge in order to account for potential block_size alignement + // issue. + max_input_tokens + 50 + } as u32; + tracing::info!("Default `max_batch_prefill_tokens` to {value}"); + value + } + } + }; + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(LauncherError::ArgumentValidation( + "`max_input_tokens must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be >= `max_input_tokens`. 
Given: {} and {}", + max_batch_prefill_tokens, max_input_tokens + ))); + } + + let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { + (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), + #[allow(deprecated)] + ( + None, + Some( + Quantization::Bitsandbytes + | Quantization::BitsandbytesNF4 + | Quantization::BitsandbytesFP4, + ), + ) => { + tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + vec![] + } + _ => { + let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; + tracing::info!("Using default cuda graphs {cuda_graphs:?}"); + cuda_graphs + } + }; + + if args.validation_workers == 0 { + return Err(LauncherError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + if args.trust_remote_code { + tracing::warn!( + "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", + args.model_id + ); + } + + let num_shard = find_num_shards(args.sharded, args.num_shard)?; + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } + + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + max_total_tokens, max_batch_total_tokens + ))); + } + } + + if args.ngrok { + if args.ngrok_authtoken.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), + )); + } + + if args.ngrok_edge.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-edge` must be set when using ngrok tunneling".to_string(), + )); + } + } + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Download and convert model weights + download_convert_model(&args, running.clone())?; + + if !running.load(Ordering::SeqCst) { + // Launcher was asked to stop + return Ok(()); + } + + // Shared shutdown bool + let shutdown = Arc::new(AtomicBool::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard status + let (status_sender, status_receiver) = mpsc::channel(); + + spawn_shards( + num_shard, + &args, + cuda_graphs, + max_total_tokens, + shutdown.clone(), + &shutdown_receiver, + shutdown_sender, + &status_receiver, + status_sender, + running.clone(), + )?; + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return Ok(()); + } + + let mut webserver = spawn_webserver( + num_shard, + args, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + shutdown.clone(), + &shutdown_receiver, + ) + .map_err(|err| { + shutdown_shards(shutdown.clone(), &shutdown_receiver); + err + })?; + + // Default exit code + let mut exit_code = Ok(()); + + while running.load(Ordering::SeqCst) { + if let Ok(ShardStatus::Failed(rank)) = 
status_receiver.try_recv() { + tracing::error!("Shard {rank} crashed"); + exit_code = Err(LauncherError::ShardFailed); + break; + }; + + match webserver.try_wait().unwrap() { + Some(_) => { + tracing::error!("Webserver Crashed"); + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::WebserverFailed); + } + None => { + sleep(Duration::from_millis(100)); + } + }; + } + + // Graceful termination + terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); + shutdown_shards(shutdown, &shutdown_receiver); + + exit_code +} + +#[allow(clippy::too_many_arguments)] +pub fn launcher_main_without_server( + model_id: String, + revision: Option, + validation_workers: usize, + sharded: Option, + num_shard: Option, + quantize: Option, + speculate: Option, + dtype: Option, + trust_remote_code: bool, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: Option, + max_input_length: Option, + max_total_tokens: Option, + waiting_served_ratio: f32, + max_batch_prefill_tokens: Option, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + cuda_graphs: Option>, + hostname: String, + port: u16, + shard_uds_path: String, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + cuda_memory_fraction: f32, + rope_scaling: Option, + rope_factor: Option, + json_output: bool, + otlp_endpoint: Option, + cors_allow_origin: Vec, + watermark_gamma: Option, + watermark_delta: Option, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_edge: Option, + tokenizer_config_path: Option, + disable_grammar_support: bool, + env: bool, + max_client_batch_size: usize, + webserver_callback: Box Result<(), LauncherError>>, +) -> Result<(), LauncherError> { + let args = Args { + model_id, + revision, + validation_workers, + sharded, + num_shard, + quantize, + speculate, + dtype, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + rope_scaling, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, + }; + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + if args.json_output { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .json() + .init(); + } else { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .compact() + .init(); + } + + if args.env { + let env_runtime = env_runtime::Env::new(); + tracing::info!("{}", env_runtime); + } + + tracing::info!("{:#?}", args); + + let get_max_position_embeddings = || -> Result> { + let model_id = args.model_id.clone(); + let mut path = std::path::Path::new(&args.model_id).to_path_buf(); + let filename = if !path.exists() { + // Assume it's a hub id + let api = Api::new()?; + let repo = if let Some(ref revision) = args.revision { + api.repo(Repo::with_revision( + model_id, + 
RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: Config = serde_json::from_str(&content)?; + + // Quantization usually means you're even more RAM constrained. + let max_default = 4096; + + let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { + (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + max_default + } else { + max_position_embeddings + } + } + _ => { + return Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))); + } + }; + Ok(max_position_embeddings) + }; + let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); + + let max_input_tokens = { + match (args.max_input_tokens, args.max_input_length) { + (Some(max_input_tokens), Some(max_input_length)) => { + return Err(LauncherError::ArgumentValidation( + format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.", + ))); + } + (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, + (None, None) => { + let value = max_position_embeddings - 1; + tracing::info!("Default `max_input_tokens` to {value}"); + value + } + } + }; + let max_total_tokens = { + match args.max_total_tokens { + Some(max_total_tokens) => max_total_tokens, + None => { + let value = max_position_embeddings; + tracing::info!("Default `max_total_tokens` to {value}"); + value + } + } + }; + let max_batch_prefill_tokens = { + match args.max_batch_prefill_tokens { + Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, + None => { + let value: u32 = if let Some(max_batch_size) = args.max_batch_size { + max_batch_size * max_input_tokens + } else { + // Adding some edge in order to account for potential block_size alignement + // issue. + max_input_tokens + 50 + } as u32; + tracing::info!("Default `max_batch_prefill_tokens` to {value}"); + value + } + } + }; + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(LauncherError::ArgumentValidation( + "`max_input_tokens must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be >= `max_input_tokens`. 
Given: {} and {}", + max_batch_prefill_tokens, max_input_tokens + ))); + } + + let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { + (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), + #[allow(deprecated)] + ( + None, + Some( + Quantization::Bitsandbytes + | Quantization::BitsandbytesNF4 + | Quantization::BitsandbytesFP4, + ), + ) => { + tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + vec![] + } + _ => { + let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; + tracing::info!("Using default cuda graphs {cuda_graphs:?}"); + cuda_graphs + } + }; + + if args.validation_workers == 0 { + return Err(LauncherError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + if args.trust_remote_code { + tracing::warn!( + "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", + args.model_id + ); + } + + let num_shard = find_num_shards(args.sharded, args.num_shard)?; + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } + + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + max_total_tokens, max_batch_total_tokens + ))); + } + } + + if args.ngrok { + if args.ngrok_authtoken.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), + )); + } + + if args.ngrok_edge.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-edge` must be set when using ngrok tunneling".to_string(), + )); + } + } + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Download and convert model weights + download_convert_model(&args, running.clone())?; + + if !running.load(Ordering::SeqCst) { + // Launcher was asked to stop + return Ok(()); + } + + // Shared shutdown bool + let shutdown = Arc::new(AtomicBool::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard status + let (status_sender, status_receiver) = mpsc::channel(); + + spawn_shards( + num_shard, + &args, + cuda_graphs, + max_total_tokens, + shutdown.clone(), + &shutdown_receiver, + shutdown_sender, + &status_receiver, + status_sender, + running.clone(), + )?; + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return Ok(()); + } + + // let mut webserver = spawn_webserver( + // num_shard, + // args, + // max_input_tokens, + // max_total_tokens, + // max_batch_prefill_tokens, + // shutdown.clone(), + // &shutdown_receiver, + // ) + // .map_err(|err| { + // shutdown_shards(shutdown.clone(), &shutdown_receiver); + // err + // })?; + + webserver_callback()?; + + println!("Webserver started"); + + // Default exit code + let mut exit_code = 
Ok(()); + + while running.load(Ordering::SeqCst) { + if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { + tracing::error!("Shard {rank} crashed"); + exit_code = Err(LauncherError::ShardFailed); + break; + }; + + // match webserver.try_wait().unwrap() { + // Some(_) => { + // tracing::error!("Webserver Crashed"); + // shutdown_shards(shutdown, &shutdown_receiver); + // return Err(LauncherError::WebserverFailed); + // } + // None => { + // sleep(Duration::from_millis(100)); + // } + // }; + } + + // Graceful termination + // terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); + shutdown_shards(shutdown, &shutdown_receiver); + + exit_code +} diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a97a75c0..b9113478 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,1582 +1,53 @@ -use clap::{Parser, ValueEnum}; -use hf_hub::{api::sync::Api, Repo, RepoType}; -use nix::sys::signal::{self, Signal}; -use nix::unistd::Pid; -use serde::Deserialize; -use std::env; -use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines}; -use std::os::unix::process::{CommandExt, ExitStatusExt}; -use std::path::Path; -use std::process::{Child, Command, ExitStatus, Stdio}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc::TryRecvError; -use std::sync::{mpsc, Arc}; -use std::thread; -use std::thread::sleep; -use std::time::{Duration, Instant}; -use std::{fs, io}; -use thiserror::Error; -use tracing_subscriber::{filter::LevelFilter, EnvFilter}; - -mod env_runtime; - -#[derive(Deserialize)] -struct RawConfig { - max_position_embeddings: Option, - n_positions: Option, - max_seq_len: Option, -} - -#[derive(Deserialize)] -struct Config { - max_position_embeddings: Option, -} - -impl From for Config { - fn from(other: RawConfig) -> Self { - let max_position_embeddings = other - .max_position_embeddings - .or(other.max_seq_len) - .or(other.n_positions); - Config { - max_position_embeddings, - } - } -} - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum Quantization { - /// 4 bit quantization. Requires a specific AWQ quantized model: - /// . - /// Should replace GPTQ models wherever possible because of the better latency - Awq, - /// 8 bit quantization, doesn't require specific model. - /// Should be a drop-in replacement to bitsandbytes with much better performance. - /// Kernels are from - Eetq, - /// 4 bit quantization. Requires a specific GTPQ quantized model: . - /// text-generation-inference will use exllama (faster) kernels wherever possible, and use - /// triton kernel (wider support) when it's not. - /// AWQ has faster kernels. - Gptq, - /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, - /// but it is known that the model will be much slower to run than the native f16. - #[deprecated( - since = "1.1.0", - note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" - )] - Bitsandbytes, - /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, - /// but it is known that the model will be much slower to run than the native f16. - BitsandbytesNF4, - /// Bitsandbytes 4bit. 
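The function above, `launcher_main_without_server`, is the heart of this refactor: it keeps weight download, shard spawning and shutdown in Rust, but replaces the `spawn_webserver` step with a caller-supplied `webserver_callback`, so a caller that embeds the launcher as a library can serve requests itself instead of having the launcher exec the `text-generation-router` binary. A minimal sketch of the callback shape, assuming the parameter is a boxed `Fn` closure (the generic bounds are not fully legible in this hunk) and with a hypothetical `start_router_in_process` standing in for whatever the embedding caller actually does:

use text_generation_launcher::LauncherError;

// Hypothetical stand-in for however the embedding caller serves HTTP,
// for instance by driving the router in-process on its own runtime.
fn start_router_in_process() -> Result<(), LauncherError> {
    Ok(())
}

// Shape of the argument handed to `launcher_main_without_server`: the launcher
// downloads weights, spawns and monitors the shards, then invokes this once
// instead of spawning a `text-generation-router` child process.
fn make_webserver_callback() -> Box<dyn Fn() -> Result<(), LauncherError>> {
    Box::new(start_router_in_process)
}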
nf4 should be preferred in most cases but maybe this one has better - /// perplexity performance for you model - BitsandbytesFP4, - /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above - /// This dtype has native ops should be the fastest if available. - /// This is currently not the fastest because of local unpacking + padding to satisfy matrix - /// multiplication limitations. - Fp8, -} - -impl std::fmt::Display for Quantization { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. - match self { - #[allow(deprecated)] - // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases - Quantization::Bitsandbytes => { - write!(f, "bitsandbytes") - } - Quantization::BitsandbytesNF4 => { - write!(f, "bitsandbytes-nf4") - } - Quantization::BitsandbytesFP4 => { - write!(f, "bitsandbytes-fp4") - } - Quantization::Gptq => { - write!(f, "gptq") - } - Quantization::Awq => { - write!(f, "awq") - } - Quantization::Eetq => { - write!(f, "eetq") - } - Quantization::Fp8 => { - write!(f, "fp8") - } - } - } -} - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum Dtype { - Float16, - #[clap(name = "bfloat16")] - BFloat16, -} - -impl std::fmt::Display for Dtype { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. - match self { - Dtype::Float16 => { - write!(f, "float16") - } - Dtype::BFloat16 => { - write!(f, "bfloat16") - } - } - } -} - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum RopeScaling { - Linear, - Dynamic, -} - -impl std::fmt::Display for RopeScaling { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. - match self { - RopeScaling::Linear => { - write!(f, "linear") - } - RopeScaling::Dynamic => { - write!(f, "dynamic") - } - } - } -} - -/// App Configuration -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - /// The name of the model to load. - /// Can be a MODEL_ID as listed on like - /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. - /// Or it can be a local directory containing the necessary files - /// as saved by `save_pretrained(...)` methods of transformers - #[clap(default_value = "bigscience/bloom-560m", long, env)] - model_id: String, - - /// The actual revision of the model if you're referring to a model - /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. - #[clap(long, env)] - revision: Option, - - /// The number of tokenizer workers used for payload validation and truncation inside the - /// router. - #[clap(default_value = "2", long, env)] - validation_workers: usize, - - /// Whether to shard the model across multiple GPUs - /// By default text-generation-inference will use all available GPUs to run - /// the model. Setting it to `false` deactivates `num_shard`. - #[clap(long, env)] - sharded: Option, - - /// The number of shards to use if you don't want to use all GPUs on a given machine. - /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` - /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to - /// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance. - #[clap(long, env)] - num_shard: Option, - - /// Whether you want the model to be quantized. 
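Because the `Display` impl shown in the hunk above is exactly what the launcher forwards to `text-generation-server --quantize`, the enum-to-flag mapping is worth illustrating. A small sketch, under the assumption that the enum and its `Display` impl now live, unchanged and public, in `launcher/src/lib.rs`:

use text_generation_launcher::Quantization;

// Illustration only: these strings are what the Python server receives for
// `--quantize` (values taken from the Display impl in this hunk).
fn main() {
    assert_eq!(Quantization::Awq.to_string(), "awq");
    assert_eq!(Quantization::Eetq.to_string(), "eetq");
    assert_eq!(Quantization::Gptq.to_string(), "gptq");
    assert_eq!(Quantization::BitsandbytesNF4.to_string(), "bitsandbytes-nf4");
}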
- #[clap(long, env, value_enum)] - quantize: Option, - - /// The number of input_ids to speculate on - /// If using a medusa model, the heads will be picked up automatically - /// Other wise, it will use n-gram speculation which is relatively free - /// in terms of compute, but the speedup heavily depends on the task. - #[clap(long, env)] - speculate: Option, - - /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. - #[clap(long, env, value_enum)] - dtype: Option, - - /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is - /// encouraged when loading a model with custom code to ensure no malicious code has been - /// contributed in a newer revision. - #[clap(long, env, value_enum)] - trust_remote_code: bool, - - /// The maximum amount of concurrent requests for this particular deployment. - /// Having a low limit will refuse clients requests instead of having them - /// wait for too long and is usually good to handle backpressure correctly. - #[clap(default_value = "128", long, env)] - max_concurrent_requests: usize, - - /// This is the maximum allowed value for clients to set `best_of`. - /// Best of makes `n` generations at the same time, and return the best - /// in terms of overall log probability over the entire generated sequence - #[clap(default_value = "2", long, env)] - max_best_of: usize, - - /// This is the maximum allowed value for clients to set `stop_sequences`. - /// Stop sequences are used to allow the model to stop on more than just - /// the EOS token, and enable more complex "prompting" where users can preprompt - /// the model in a specific way and define their "own" stop token aligned with - /// their prompt. - #[clap(default_value = "4", long, env)] - max_stop_sequences: usize, - - /// This is the maximum allowed value for clients to set `top_n_tokens`. - /// `top_n_tokens is used to return information about the the `n` most likely - /// tokens at each generation step, instead of just the sampled token. This - /// information can be used for downstream tasks like for classification or - /// ranking. - #[clap(default_value = "5", long, env)] - max_top_n_tokens: u32, - - /// This is the maximum allowed input length (expressed in number of tokens) - /// for users. The larger this value, the longer prompt users can send which - /// can impact the overall memory required to handle the load. - /// Please note that some models have a finite range of sequence they can handle. - /// Default to min(max_position_embeddings - 1, 4095) - #[clap(long, env)] - max_input_tokens: Option, - - /// Legacy version of [`Args::max_input_tokens`]. - #[clap(long, env)] - max_input_length: Option, - - /// This is the most important value to set as it defines the "memory budget" - /// of running clients requests. - /// Clients will send input sequences and ask to generate `max_new_tokens` - /// on top. with a value of `1512` users can send either a prompt of - /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for - /// `1511` max_new_tokens. - /// The larger this value, the larger amount each request will be in your RAM - /// and the less effective batching can be. - /// Default to min(max_position_embeddings, 4096) - #[clap(long, env)] - max_total_tokens: Option, - - /// This represents the ratio of waiting queries vs running queries where - /// you want to start considering pausing the running queries to include the waiting - /// ones into the same batch. 
- /// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's - /// only 10 queries left in the current batch we check if we can fit those 12 - /// waiting queries into the batching strategy, and if yes, then batching happens - /// delaying the 10 running queries by a `prefill` run. - /// - /// This setting is only applied if there is room in the batch - /// as defined by `max_batch_total_tokens`. - #[clap(default_value = "0.3", long, env)] - waiting_served_ratio: f32, - - /// Limits the number of tokens for the prefill operation. - /// Since this operation take the most memory and is compute bound, it is interesting - /// to limit the number of requests that can be sent. - /// Default to `max_input_tokens + 50` to give a bit of room. - #[clap(long, env)] - max_batch_prefill_tokens: Option, - - /// **IMPORTANT** This is one critical control to allow maximum usage - /// of the available hardware. - /// - /// This represents the total amount of potential tokens within a batch. - /// When using padding (not recommended) this would be equivalent of - /// `batch_size` * `max_total_tokens`. - /// - /// However in the non-padded (flash attention) version this can be much finer. - /// - /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` - /// or a single query of `1000` tokens. - /// - /// Overall this number should be the largest possible amount that fits the - /// remaining memory (after the model is loaded). Since the actual memory overhead - /// depends on other parameters like if you're using quantization, flash attention - /// or the model implementation, text-generation-inference cannot infer this number - /// automatically. - #[clap(long, env)] - max_batch_total_tokens: Option, - - /// This setting defines how many tokens can be passed before forcing the waiting - /// queries to be put on the batch (if the size of the batch allows for it). - /// New queries require 1 `prefill` forward, which is different from `decode` - /// and therefore you need to pause the running batch in order to run `prefill` - /// to create the correct values for the waiting queries to be able to join the batch. - /// - /// With a value too small, queries will always "steal" the compute to run `prefill` - /// and running queries will be delayed by a lot. - /// - /// With a value too big, waiting queries could wait for a very long time - /// before being allowed a slot in the running batch. If your server is busy - /// that means that requests that could run in ~2s on an empty server could - /// end up running in ~20s because the query had to wait for 18s. - /// - /// This number is expressed in number of tokens to make it a bit more - /// "model" agnostic, but what should really matter is the overall latency - /// for end users. - #[clap(default_value = "20", long, env)] - max_waiting_tokens: usize, - - /// Enforce a maximum number of requests per batch - /// Specific flag for hardware targets that do not support unpadded inference - #[clap(long, env)] - max_batch_size: Option, - - /// Specify the batch sizes to compute cuda graphs for. - /// Use "0" to disable. - /// Default = "1,2,4,8,16,32" - #[clap(long, env, value_delimiter = ',')] - cuda_graphs: Option>, - - /// The IP address to listen on - #[clap(default_value = "0.0.0.0", long, env)] - hostname: String, - - /// The port to listen on. - #[clap(default_value = "3000", long, short, env)] - port: u16, - - /// The name of the socket for gRPC communication between the webserver - /// and the shards. 
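The prefill default described above (`max_input_tokens + 50` "to give a bit of room") also interacts with `max_batch_size`. A small sketch that mirrors the defaulting block added earlier in this patch to `launcher/src/lib.rs`; the function name is local to the example:

// Mirrors how the launcher derives `max_batch_prefill_tokens` when the flag
// is not set (see the defaulting logic in launcher/src/lib.rs).
fn default_max_batch_prefill_tokens(max_input_tokens: usize, max_batch_size: Option<usize>) -> u32 {
    match max_batch_size {
        // Fixed-batch targets that cannot do unpadded inference may prefill a whole batch at once.
        Some(batch_size) => (batch_size * max_input_tokens) as u32,
        // Otherwise add a small margin for potential block_size alignment issues.
        None => (max_input_tokens + 50) as u32,
    }
}

fn main() {
    assert_eq!(default_max_batch_prefill_tokens(4095, None), 4145);
    assert_eq!(default_max_batch_prefill_tokens(1024, Some(4)), 4096);
}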
- #[clap(default_value = "/tmp/text-generation-server", long, env)] - shard_uds_path: String, - - /// The address the master shard will listen on. (setting used by torch distributed) - #[clap(default_value = "localhost", long, env)] - master_addr: String, - - /// The address the master port will listen on. (setting used by torch distributed) - #[clap(default_value = "29500", long, env)] - master_port: usize, - - /// The location of the huggingface hub cache. - /// Used to override the location if you want to provide a mounted disk for instance - #[clap(long, env)] - huggingface_hub_cache: Option, - - /// The location of the huggingface hub cache. - /// Used to override the location if you want to provide a mounted disk for instance - #[clap(long, env)] - weights_cache_override: Option, - - /// For some models (like bloom), text-generation-inference implemented custom - /// cuda kernels to speed up inference. Those kernels were only tested on A100. - /// Use this flag to disable them if you're running on different hardware and - /// encounter issues. - #[clap(long, env)] - disable_custom_kernels: bool, - - /// Limit the CUDA available memory. - /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction. - #[clap(default_value = "1.0", long, env)] - cuda_memory_fraction: f32, - - /// Rope scaling will only be used for RoPE models - /// and allow rescaling the position rotary to accomodate for - /// larger prompts. - /// - /// Goes together with `rope_factor`. - /// - /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0 - /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 - /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed - /// basically) - /// - /// `--rope-scaling linear --rope-factor` fully describes the scaling you want - #[clap(long, env)] - rope_scaling: Option, - - /// Rope scaling will only be used for RoPE models - /// See `rope_scaling` - #[clap(long, env)] - rope_factor: Option, - - /// Outputs the logs in JSON format (useful for telemetry) - #[clap(long, env)] - json_output: bool, - - #[clap(long, env)] - otlp_endpoint: Option, - - #[clap(long, env)] - cors_allow_origin: Vec, - #[clap(long, env)] - watermark_gamma: Option, - #[clap(long, env)] - watermark_delta: Option, - - /// Enable ngrok tunneling - #[clap(long, env)] - ngrok: bool, - - /// ngrok authentication token - #[clap(long, env)] - ngrok_authtoken: Option, - - /// ngrok edge - #[clap(long, env)] - ngrok_edge: Option, - - /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may - /// include a `chat_template`. If not provided, the default config will be used from the model hub. - #[clap(long, env)] - tokenizer_config_path: Option, - - /// Disable outlines grammar constrained generation. - /// This is a feature that allows you to generate text that follows a specific grammar. 
- #[clap(long, env)] - disable_grammar_support: bool, - - /// Display a lot of information about your runtime environment - #[clap(long, short, action)] - env: bool, - - /// Control the maximum number of inputs that a client can send in a single request - #[clap(default_value = "4", long, env)] - max_client_batch_size: usize, -} - -#[derive(Debug)] -enum ShardStatus { - Ready, - Failed(usize), -} - -#[allow(clippy::too_many_arguments)] -fn shard_manager( - model_id: String, - revision: Option, - quantize: Option, - speculate: Option, - dtype: Option, - trust_remote_code: bool, - uds_path: String, - rank: usize, - world_size: usize, - master_addr: String, - master_port: usize, - huggingface_hub_cache: Option, - weights_cache_override: Option, - disable_custom_kernels: bool, - watermark_gamma: Option, - watermark_delta: Option, - cuda_graphs: Vec, - cuda_memory_fraction: f32, - rope_scaling: Option, - rope_factor: Option, - max_total_tokens: usize, - max_batch_size: Option, - otlp_endpoint: Option, - log_level: LevelFilter, - status_sender: mpsc::Sender, - shutdown: Arc, - _shutdown_sender: mpsc::Sender<()>, -) { - // Enter shard-manager tracing span - let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - - // Get UDS path - let uds_string = format!("{uds_path}-{rank}"); - let uds = Path::new(&uds_string); - // Clean previous runs - if uds.exists() { - fs::remove_file(uds).unwrap(); - } - - // Process args - let mut shard_args = vec![ - "serve".to_string(), - model_id, - "--uds-path".to_string(), - uds_path, - "--logger-level".to_string(), - log_level.to_string().to_uppercase(), - "--json-output".to_string(), - ]; - - // Activate trust remote code - if trust_remote_code { - shard_args.push("--trust-remote-code".to_string()); - } - - // Activate tensor parallelism - if world_size > 1 { - shard_args.push("--sharded".to_string()); - } - - if let Some(quantize) = quantize { - shard_args.push("--quantize".to_string()); - shard_args.push(quantize.to_string()) - } - - if let Some(speculate) = speculate { - shard_args.push("--speculate".to_string()); - shard_args.push(speculate.to_string()) - } - - if let Some(dtype) = dtype { - shard_args.push("--dtype".to_string()); - shard_args.push(dtype.to_string()) - } - - // Model optional revision - if let Some(revision) = revision { - shard_args.push("--revision".to_string()); - shard_args.push(revision) - } - - let rope = match (rope_scaling, rope_factor) { - (None, None) => None, - (Some(scaling), None) => Some((scaling, 1.0)), - (Some(scaling), Some(factor)) => Some((scaling, factor)), - (None, Some(factor)) => Some((RopeScaling::Linear, factor)), - }; - - // OpenTelemetry - if let Some(otlp_endpoint) = otlp_endpoint { - shard_args.push("--otlp-endpoint".to_string()); - shard_args.push(otlp_endpoint); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Remove LOG_LEVEL if present - envs.retain(|(name, _)| name != "LOG_LEVEL"); - - // Torch Distributed Env vars - envs.push(("RANK".into(), rank.to_string().into())); - envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); - envs.push(("MASTER_ADDR".into(), master_addr.into())); - envs.push(("MASTER_PORT".into(), master_port.to_string().into())); - envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into())); - - // CUDA memory fraction - envs.push(( - "CUDA_MEMORY_FRACTION".into(), - cuda_memory_fraction.to_string().into(), - )); - - // Safetensors load fast - 
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); - - // Disable progress bar - envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); - - // Enable hf transfer for insane download speeds - let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - envs.push(( - "HF_HUB_ENABLE_HF_TRANSFER".into(), - enable_hf_transfer.into(), - )); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // Detect rope scaling - // Sending as env instead of CLI args to not bloat everything - // those only can be used by RoPE models, so passing information around - // for all models will complexify code unnecessarily - if let Some((scaling, factor)) = rope { - envs.push(("ROPE_SCALING".into(), scaling.to_string().into())); - envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); - } - - envs.push(( - "MAX_TOTAL_TOKENS".into(), - max_total_tokens.to_string().into(), - )); - if let Some(max_batch_size) = max_batch_size { - envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into())); - } - - // If huggingface_hub_cache is some, pass it to the shard - // Useful when running inside a docker container - if let Some(huggingface_hub_cache) = huggingface_hub_cache { - envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); - }; - - // If weights_cache_override is some, pass it to the shard - // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = weights_cache_override { - envs.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - - // Enable experimental support for cuda graphs - if !cuda_graphs.is_empty() { - envs.push(( - "CUDA_GRAPHS".into(), - cuda_graphs - .into_iter() - .map(|c| c.to_string()) - .collect::>() - .join(",") - .into(), - )); - } - - // If disable_custom_kernels is true, pass it to the shard as an env var - if disable_custom_kernels { - envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) - } - - // Watermark Gamma - if let Some(watermark_gamma) = watermark_gamma { - envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) - } - - // Watermark Delta - if let Some(watermark_delta) = watermark_delta { - envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) - } - - // Start process - tracing::info!("Starting shard"); - let mut p = match Command::new("text-generation-server") - .args(shard_args) - .env_clear() - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } - { - tracing::error!("{}", err); - } - - status_sender.send(ShardStatus::Failed(rank)).unwrap(); - return; - } - }; - - // Redirect STDOUT to the console - let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); - let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); - - //stdout tracing thread - thread::spawn(move || { - log_lines(shard_stdout_reader.lines()); - }); - // We read stderr in another thread as it seems that lines() can block in some cases - let (err_sender, err_receiver) = mpsc::channel(); - thread::spawn(move || { - for line in shard_stderr_reader.lines().map_while(Result::ok) { - err_sender.send(line).unwrap_or(()); - } - }); - - let mut 
ready = false; - let start_time = Instant::now(); - let mut wait_time = Instant::now(); - loop { - // Process exited - if let Some(exit_status) = p.try_wait().unwrap() { - let mut err = String::new(); - while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { - err = err + "\n" + &line; - } - - tracing::error!("Shard complete standard error output:\n{err}"); - - if let Some(signal) = exit_status.signal() { - tracing::error!("Shard process was signaled to shutdown with signal {signal}"); - } - - status_sender.send(ShardStatus::Failed(rank)).unwrap(); - return; - } - - // We received a shutdown signal - if shutdown.load(Ordering::SeqCst) { - terminate("shard", p, Duration::from_secs(90)).unwrap(); - return; - } - - // Shard is ready - if uds.exists() && !ready { - tracing::info!("Shard ready in {:?}", start_time.elapsed()); - status_sender.send(ShardStatus::Ready).unwrap(); - ready = true; - } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard to be ready..."); - wait_time = Instant::now(); - } - sleep(Duration::from_millis(100)); - } -} - -fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { - tracing::info!("Shutting down shards"); - // Update shutdown value to true - // This will be picked up by the shard manager - shutdown.store(true, Ordering::SeqCst); - - // Wait for shards to shutdown - // This will block till all shutdown_sender are dropped - let _ = shutdown_receiver.recv(); -} - -fn num_cuda_devices() -> Option { - let devices = match env::var("CUDA_VISIBLE_DEVICES") { - Ok(devices) => devices, - Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, - }; - let n_devices = devices.split(',').count(); - Some(n_devices) -} - -#[derive(Deserialize)] -#[serde(rename_all = "UPPERCASE")] -enum PythonLogLevelEnum { - Trace, - Debug, - Info, - Success, - Warning, - Error, - Critical, -} - -#[derive(Deserialize)] -struct PythonLogLevel { - name: PythonLogLevelEnum, -} - -#[derive(Deserialize)] -struct PythonLogRecord { - level: PythonLogLevel, -} - -#[derive(Deserialize)] -struct PythonLogMessage { - text: String, - record: PythonLogRecord, -} - -impl PythonLogMessage { - fn trace(&self) { - match self.record.level.name { - PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()), - PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()), - PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()), - PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()), - PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()), - PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()), - PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()), - } - } -} - -impl TryFrom<&String> for PythonLogMessage { - type Error = serde_json::Error; - - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) - } -} - -fn log_lines(lines: Lines) { - for line in lines.map_while(Result::ok) { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), - } - } -} - -fn find_num_shards( - sharded: Option, - num_shard: Option, -) -> Result { - // get the number of shards given `sharded` and `num_shard` - let num_shard = match (sharded, num_shard) { - (Some(true), None) => { - // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); - let n_devices = num_cuda_devices() - 
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); - if n_devices <= 1 { - return Err(LauncherError::NotEnoughCUDADevices(format!( - "`sharded` is true but only found {n_devices} CUDA devices" - ))); - } - n_devices - } - (Some(true), Some(num_shard)) => { - // we can't have only one shard while sharded - if num_shard <= 1 { - return Err(LauncherError::ArgumentValidation( - "`sharded` is true but `num_shard` <= 1".to_string(), - )); - } - num_shard - } - (Some(false), Some(num_shard)) => num_shard, - (Some(false), None) => 1, - (None, None) => num_cuda_devices().unwrap_or(1), - (None, Some(num_shard)) => num_shard, - }; - if num_shard < 1 { - return Err(LauncherError::ArgumentValidation( - "`num_shard` cannot be < 1".to_string(), - )); - } - Ok(num_shard) -} - -#[derive(Debug, Error)] -enum LauncherError { - #[error("Invalid argument: {0}")] - ArgumentValidation(String), - #[error("not enough cuda devices: {0}")] - NotEnoughCUDADevices(String), - #[error("Download error")] - DownloadError, - #[error("Shard cannot start")] - ShardCannotStart, - #[error("Shard disconnected")] - ShardDisconnected, - #[error("Shard failed")] - ShardFailed, - #[error("Webserver failed")] - WebserverFailed, - #[error("Webserver cannot start")] - WebserverCannotStart, -} - -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { - // Enter download tracing span - let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - - let mut download_args = vec![ - "download-weights".to_string(), - args.model_id.to_string(), - "--extension".to_string(), - ".safetensors".to_string(), - "--logger-level".to_string(), - "INFO".to_string(), - "--json-output".to_string(), - ]; - - // Model optional revision - if let Some(revision) = &args.revision { - download_args.push("--revision".to_string()); - download_args.push(revision.to_string()) - } - - // Trust remote code for automatic peft fusion - if args.trust_remote_code { - download_args.push("--trust-remote-code".to_string()); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Remove LOG_LEVEL if present - envs.retain(|(name, _)| name != "LOG_LEVEL"); - - // Disable progress bar - envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); - - // If huggingface_hub_cache is set, pass it to the download process - // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { - envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); - }; - - // Enable hf transfer for insane download speeds - let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - envs.push(( - "HF_HUB_ENABLE_HF_TRANSFER".into(), - enable_hf_transfer.into(), - )); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // If args.weights_cache_override is some, pass it to the download process - // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { - envs.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - - // Start process - tracing::info!("Starting download process."); - let mut download_process = match Command::new("text-generation-server") - .args(download_args) - .env_clear() - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - 
.process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } else { - tracing::error!("{}", err); - } - - return Err(LauncherError::DownloadError); - } - }; - - let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); - - thread::spawn(move || { - log_lines(download_stdout.lines()); - }); - - let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); - - // We read stderr in another thread as it seems that lines() can block in some cases - let (err_sender, err_receiver) = mpsc::channel(); - thread::spawn(move || { - for line in download_stderr.lines().map_while(Result::ok) { - err_sender.send(line).unwrap_or(()); - } - }); - - loop { - if let Some(status) = download_process.try_wait().unwrap() { - if status.success() { - tracing::info!("Successfully downloaded weights."); - break; - } - - let mut err = String::new(); - while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { - err = err + "\n" + &line; - } - - if let Some(signal) = status.signal() { - tracing::error!( - "Download process was signaled to shutdown with signal {signal}: {err}" - ); - } else { - tracing::error!("Download encountered an error: {err}"); - } - - return Err(LauncherError::DownloadError); - } - if !running.load(Ordering::SeqCst) { - terminate("download", download_process, Duration::from_secs(10)).unwrap(); - return Ok(()); - } - sleep(Duration::from_millis(100)); - } - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn spawn_shards( - num_shard: usize, - args: &Args, - cuda_graphs: Vec, - max_total_tokens: usize, - max_log_level: LevelFilter, - shutdown: Arc, - shutdown_receiver: &mpsc::Receiver<()>, - shutdown_sender: mpsc::Sender<()>, - status_receiver: &mpsc::Receiver, - status_sender: mpsc::Sender, - running: Arc, -) -> Result<(), LauncherError> { - // Start shard processes - for rank in 0..num_shard { - let model_id = args.model_id.clone(); - let revision = args.revision.clone(); - let uds_path = args.shard_uds_path.clone(); - let master_addr = args.master_addr.clone(); - let huggingface_hub_cache = args.huggingface_hub_cache.clone(); - let weights_cache_override = args.weights_cache_override.clone(); - let status_sender = status_sender.clone(); - let shutdown = shutdown.clone(); - let shutdown_sender = shutdown_sender.clone(); - let otlp_endpoint = args.otlp_endpoint.clone(); - let quantize = args.quantize; - let speculate = args.speculate; - let dtype = args.dtype; - let trust_remote_code = args.trust_remote_code; - let master_port = args.master_port; - let disable_custom_kernels = args.disable_custom_kernels; - let watermark_gamma = args.watermark_gamma; - let watermark_delta = args.watermark_delta; - let cuda_graphs_clone = cuda_graphs.clone(); - let cuda_memory_fraction = args.cuda_memory_fraction; - let rope_scaling = args.rope_scaling; - let rope_factor = args.rope_factor; - let max_batch_size = args.max_batch_size; - thread::spawn(move || { - shard_manager( - model_id, - revision, - quantize, - speculate, - dtype, - trust_remote_code, - uds_path, - rank, - num_shard, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - watermark_gamma, - watermark_delta, - cuda_graphs_clone, - cuda_memory_fraction, - rope_scaling, - rope_factor, - max_total_tokens, - max_batch_size, - otlp_endpoint, - max_log_level, - 
status_sender, - shutdown, - shutdown_sender, - ) - }); - } - drop(shutdown_sender); - - // Wait for shard to start - let mut shard_ready = 0; - while running.load(Ordering::SeqCst) { - match status_receiver.try_recv() { - Ok(ShardStatus::Ready) => { - shard_ready += 1; - if shard_ready == num_shard { - break; - } - } - Err(TryRecvError::Empty) => { - sleep(Duration::from_millis(100)); - } - Ok(ShardStatus::Failed(rank)) => { - tracing::error!("Shard {rank} failed to start"); - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::ShardCannotStart); - } - Err(TryRecvError::Disconnected) => { - tracing::error!("Shard status channel disconnected"); - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::ShardDisconnected); - } - } - } - Ok(()) -} - -fn compute_type(num_shard: usize) -> Option { - let output = Command::new("nvidia-smi") - .args(["--query-gpu=gpu_name", "--format=csv"]) - .output() - .ok()?; - let output = String::from_utf8(output.stdout).ok()?; - let fullname = output.split('\n').nth(1)?; - let cardname = fullname.replace(' ', "-").to_lowercase(); - let compute_type = format!("{num_shard}-{cardname}"); - Some(compute_type) -} - -fn spawn_webserver( - num_shard: usize, - args: Args, - max_input_tokens: usize, - max_total_tokens: usize, - max_batch_prefill_tokens: u32, - shutdown: Arc, - shutdown_receiver: &mpsc::Receiver<()>, -) -> Result { - // All shard started - // Start webserver - tracing::info!("Starting Webserver"); - let mut router_args = vec![ - "--max-client-batch-size".to_string(), - args.max_client_batch_size.to_string(), - "--max-concurrent-requests".to_string(), - args.max_concurrent_requests.to_string(), - "--max-best-of".to_string(), - args.max_best_of.to_string(), - "--max-stop-sequences".to_string(), - args.max_stop_sequences.to_string(), - "--max-top-n-tokens".to_string(), - args.max_top_n_tokens.to_string(), - "--max-input-tokens".to_string(), - max_input_tokens.to_string(), - "--max-total-tokens".to_string(), - max_total_tokens.to_string(), - "--max-batch-prefill-tokens".to_string(), - max_batch_prefill_tokens.to_string(), - "--waiting-served-ratio".to_string(), - args.waiting_served_ratio.to_string(), - "--max-waiting-tokens".to_string(), - args.max_waiting_tokens.to_string(), - "--validation-workers".to_string(), - args.validation_workers.to_string(), - "--hostname".to_string(), - args.hostname.to_string(), - "--port".to_string(), - args.port.to_string(), - "--master-shard-uds-path".to_string(), - format!("{}-0", args.shard_uds_path), - "--tokenizer-name".to_string(), - args.model_id, - ]; - - // Grammar support - if args.disable_grammar_support { - router_args.push("--disable-grammar-support".to_string()); - } - - // Tokenizer config path - if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { - router_args.push("--tokenizer-config-path".to_string()); - router_args.push(tokenizer_config_path.to_string()); - } - - // Model optional max batch total tokens - if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { - router_args.push("--max-batch-total-tokens".to_string()); - router_args.push(max_batch_total_tokens.to_string()); - } - - // Router optional max batch size - if let Some(max_batch_size) = args.max_batch_size { - router_args.push("--max-batch-size".to_string()); - router_args.push(max_batch_size.to_string()); - } - - // Model optional revision - if let Some(ref revision) = args.revision { - router_args.push("--revision".to_string()); - router_args.push(revision.to_string()) - } - 
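The `compute_type` helper in the hunk above derives the value used to populate the webserver's COMPUTE_TYPE environment variable (set a few lines below) from the shard count plus the first GPU name reported by nvidia-smi. A hypothetical illustration; the GPU name is an assumed example, not taken from this patch:

fn main() {
    let num_shard = 2;
    let fullname = "NVIDIA A100-SXM4-80GB"; // assumed `nvidia-smi --query-gpu=gpu_name` output
    let cardname = fullname.replace(' ', "-").to_lowercase();
    // Same formatting as compute_type(): "{num_shard}-{cardname}"
    assert_eq!(format!("{num_shard}-{cardname}"), "2-nvidia-a100-sxm4-80gb");
}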
- if args.json_output { - router_args.push("--json-output".to_string()); - } - - // OpenTelemetry - if let Some(otlp_endpoint) = args.otlp_endpoint { - router_args.push("--otlp-endpoint".to_string()); - router_args.push(otlp_endpoint); - } - - // CORS origins - for origin in args.cors_allow_origin.into_iter() { - router_args.push("--cors-allow-origin".to_string()); - router_args.push(origin); - } - - // Ngrok - if args.ngrok { - router_args.push("--ngrok".to_string()); - router_args.push("--ngrok-authtoken".to_string()); - router_args.push(args.ngrok_authtoken.unwrap()); - router_args.push("--ngrok-edge".to_string()); - router_args.push(args.ngrok_edge.unwrap()); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // Parse Compute type - if let Ok(compute_type) = env::var("COMPUTE_TYPE") { - envs.push(("COMPUTE_TYPE".into(), compute_type.into())) - } else if let Some(compute_type) = compute_type(num_shard) { - envs.push(("COMPUTE_TYPE".into(), compute_type.into())) - } - - let mut webserver = match Command::new("text-generation-router") - .args(router_args) - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - tracing::error!("Failed to start webserver: {}", err); - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-router not found in PATH"); - tracing::error!("Please install it with `make install-router`") - } else { - tracing::error!("{}", err); - } - - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::WebserverCannotStart); - } - }; - - // Redirect STDOUT and STDERR to the console - let webserver_stdout = webserver.stdout.take().unwrap(); - let webserver_stderr = webserver.stderr.take().unwrap(); - - thread::spawn(move || { - let stdout = BufReader::new(webserver_stdout); - let stderr = BufReader::new(webserver_stderr); - for line in stdout.lines() { - println!("{}", line.unwrap()); - } - for line in stderr.lines() { - println!("{}", line.unwrap()); - } - }); - Ok(webserver) -} - -fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result { - tracing::info!("Terminating {process_name}"); - - let terminate_time = Instant::now(); - signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); - - tracing::info!("Waiting for {process_name} to gracefully shutdown"); - while terminate_time.elapsed() < timeout { - if let Some(status) = process.try_wait()? { - tracing::info!("{process_name} terminated"); - return Ok(status); - } - sleep(Duration::from_millis(100)); - } - tracing::info!("Killing {process_name}"); - - process.kill()?; - let exit_status = process.wait()?; - - tracing::info!("{process_name} killed"); - Ok(exit_status) -} +use clap::Parser; +use text_generation_launcher::{launcher_main, Args, LauncherError}; fn main() -> Result<(), LauncherError> { - // Pattern match configuration - let args: Args = Args::parse(); - - // Filter events with LOG_LEVEL - let varname = "LOG_LEVEL"; - let env_filter = if let Ok(log_level) = std::env::var(varname) { - // Override to avoid simple logs to be spammed with tokio level informations - let log_level = match &log_level[..] 
{ - "warn" => "text_generation_launcher=warn,text_generation_router=warn", - "info" => "text_generation_launcher=info,text_generation_router=info", - "debug" => "text_generation_launcher=debug,text_generation_router=debug", - log_level => log_level, - }; - EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .parse_lossy(log_level) - } else { - EnvFilter::new("info") - }; - let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO); - - if args.json_output { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .json() - .init(); - } else { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .compact() - .init(); - } - - if args.env { - let env_runtime = env_runtime::Env::new(); - tracing::info!("{}", env_runtime); - } - - tracing::info!("{:#?}", args); - - let get_max_position_embeddings = || -> Result> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id - let api = Api::new()?; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; - let config: Config = config.into(); - - // Quantization usually means you're even more RAM constrained. - let max_default = 4096; - - if let Some(max_position_embeddings) = config.max_position_embeddings { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - Ok(max_default) - } else { - Ok(max_position_embeddings) - } - } else { - Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))) - } - }; - let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); - - let max_input_tokens = { - match (args.max_input_tokens, args.max_input_length) { - (Some(max_input_tokens), Some(max_input_length)) => { - return Err(LauncherError::ArgumentValidation( - format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. 
Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.", - ))); - } - (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, - (None, None) => { - let value = max_position_embeddings - 1; - tracing::info!("Default `max_input_tokens` to {value}"); - value - } - } - }; - let max_total_tokens = { - match args.max_total_tokens { - Some(max_total_tokens) => max_total_tokens, - None => { - let value = max_position_embeddings; - tracing::info!("Default `max_total_tokens` to {value}"); - value - } - } - }; - let max_batch_prefill_tokens = { - match args.max_batch_prefill_tokens { - Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, - None => { - let value: u32 = if let Some(max_batch_size) = args.max_batch_size { - max_batch_size * max_input_tokens - } else { - // Adding some edge in order to account for potential block_size alignement - // issue. - max_input_tokens + 50 - } as u32; - tracing::info!("Default `max_batch_prefill_tokens` to {value}"); - value - } - } - }; - - // Validate args - if max_input_tokens >= max_total_tokens { - return Err(LauncherError::ArgumentValidation( - "`max_input_tokens must be < `max_total_tokens`".to_string(), - )); - } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_input_tokens - ))); - } - - let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { - (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), - #[allow(deprecated)] - ( - None, - Some( - Quantization::Bitsandbytes - | Quantization::BitsandbytesNF4 - | Quantization::BitsandbytesFP4, - ), - ) => { - tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); - vec![] - } - _ => { - let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; - tracing::info!("Using default cuda graphs {cuda_graphs:?}"); - cuda_graphs - } - }; - - if args.validation_workers == 0 { - return Err(LauncherError::ArgumentValidation( - "`validation_workers` must be > 0".to_string(), - )); - } - if args.trust_remote_code { - tracing::warn!( - "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", - args.model_id - ); - } - - let num_shard = find_num_shards(args.sharded, args.num_shard)?; - if num_shard > 1 { - tracing::info!("Sharding model on {num_shard} processes"); - } - - if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_batch_total_tokens - ))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", - max_total_tokens, max_batch_total_tokens - ))); - } - } - - if args.ngrok { - if args.ngrok_authtoken.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), - )); - } - - if args.ngrok_edge.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-edge` must be set when using ngrok tunneling".to_string(), - )); - } - } - - // Signal handler - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - // Download and convert model weights - download_convert_model(&args, running.clone())?; - - if !running.load(Ordering::SeqCst) { - // Launcher was asked to stop - return Ok(()); - } - - // Shared shutdown bool - let shutdown = Arc::new(AtomicBool::new(false)); - // Shared shutdown channel - // When shutting down, the main thread will wait for all senders to be dropped - let (shutdown_sender, shutdown_receiver) = mpsc::channel(); - - // Shared channel to track shard status - let (status_sender, status_receiver) = mpsc::channel(); - - spawn_shards( - num_shard, - &args, - cuda_graphs, - max_total_tokens, - max_log_level, - shutdown.clone(), - &shutdown_receiver, - shutdown_sender, - &status_receiver, - status_sender, - running.clone(), - )?; - - // We might have received a termination signal - if !running.load(Ordering::SeqCst) { - shutdown_shards(shutdown, &shutdown_receiver); - return Ok(()); - } - - let mut webserver = spawn_webserver( - num_shard, - args, - max_input_tokens, - max_total_tokens, - max_batch_prefill_tokens, - shutdown.clone(), - &shutdown_receiver, + let args = Args::parse(); + launcher_main( + args.model_id, + args.revision, + args.validation_workers, + args.sharded, + args.num_shard, + args.quantize, + args.speculate, + args.dtype, + args.trust_remote_code, + args.max_concurrent_requests, + args.max_best_of, + args.max_stop_sequences, + args.max_top_n_tokens, + args.max_input_tokens, + args.max_input_length, + args.max_total_tokens, + args.waiting_served_ratio, + args.max_batch_prefill_tokens, + args.max_batch_total_tokens, + args.max_waiting_tokens, + args.max_batch_size, + args.cuda_graphs, + args.hostname, + args.port, + args.shard_uds_path, + args.master_addr, + args.master_port, + args.huggingface_hub_cache, + args.weights_cache_override, + args.disable_custom_kernels, + args.cuda_memory_fraction, + args.rope_scaling, + args.rope_factor, + args.json_output, + args.otlp_endpoint, + args.cors_allow_origin, + args.watermark_gamma, + args.watermark_delta, + args.ngrok, + args.ngrok_authtoken, + args.ngrok_edge, + args.tokenizer_config_path, + args.disable_grammar_support, + args.env, + args.max_client_batch_size, ) - .map_err(|err| { - shutdown_shards(shutdown.clone(), &shutdown_receiver); - err - })?; - - // Default exit code - let mut exit_code = Ok(()); - - while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { - tracing::error!("Shard {rank} crashed"); - exit_code = Err(LauncherError::ShardFailed); - break; - }; - - match webserver.try_wait().unwrap() { - Some(_) => { - tracing::error!("Webserver Crashed"); - shutdown_shards(shutdown, &shutdown_receiver); - return Err(LauncherError::WebserverFailed); - } - None => { - sleep(Duration::from_millis(100)); - } - }; - } - - // Graceful termination - terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); - 
shutdown_shards(shutdown, &shutdown_receiver); - - exit_code } diff --git a/router/src/lib.rs b/router/src/lib.rs index 9b3283df..a4950957 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -6,15 +6,491 @@ mod queue; pub mod server; mod validation; +use axum::http::HeaderValue; +use config::Config; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; use infer::{Infer, InferError, InferStreamResponse}; +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; use queue::{Entry, Queue}; use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; +use text_generation_client::{ClientError, ShardedClient}; +use thiserror::Error; +use tokenizers::Tokenizer; use tokio::sync::OwnedSemaphorePermit; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; use utoipa::ToSchema; use validation::Validation; +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +fn init_logging(otlp_endpoint: Option, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + "text-generation-inference.router", + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} + +/// get model info from the Huggingface Hub +pub async fn get_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = 
api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + +#[derive(Debug, Error)] +pub enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), + #[error("Axum webserver failed: {0}")] + Axum(#[from] axum::BoxError), +} + +#[allow(clippy::too_many_arguments)] +pub async fn internal_main( + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + hostname: String, + port: u16, + master_shard_uds_path: String, + tokenizer_name: String, + tokenizer_config_path: Option, + revision: Option, + validation_workers: usize, + json_output: bool, + otlp_endpoint: Option, + cors_allow_origin: Option>, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_edge: Option, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, +) -> Result<(), RouterError> { + // Launch Tokio runtime + if otlp_endpoint.is_some() { + // Initialize if OpenTelemetry is enabled + init_logging(otlp_endpoint, json_output); + } + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. 
Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); + + // Parse Huggingface hub token + let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, + } + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = Cache::default(); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None + } + } + } + } else { + Type::None + }; + + // Load tokenizer and model info + let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), 
+ RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + None, + ) + } + }; + let tokenizer: Option = + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + + // if pipeline-tag == text-generation we default to return_full_text = true + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + + // Instantiate sharded client from the master unix socket + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(RouterError::Connection)?; + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(RouterError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_supported_batch_total_tokens = match sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); + } + + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); + tracing::info!("Connected"); + + // Determine the server port based on the feature and environment variable. + let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Run server + server::run( + model_info, + shard_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_supported_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + sharded_client, + tokenizer, + config, + validation_workers, + addr, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + ) + .await?; + Ok(()) +} + /// Type alias for generation responses pub(crate) type GenerateStreamResponse = ( OwnedSemaphorePermit, diff --git a/router/src/main.rs b/router/src/main.rs index b526367c..ca11801c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,26 +1,5 @@ -use axum::http::HeaderValue; use clap::Parser; -use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; -use hf_hub::{Cache, Repo, RepoType}; -use opentelemetry::sdk::propagation::TraceContextPropagator; -use opentelemetry::sdk::trace; -use opentelemetry::sdk::trace::Sampler; -use opentelemetry::sdk::Resource; -use opentelemetry::{global, KeyValue}; -use opentelemetry_otlp::WithExportConfig; -use std::fs::File; -use std::io::BufReader; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::path::{Path, PathBuf}; -use text_generation_client::{ClientError, ShardedClient}; -use text_generation_router::config::Config; -use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; -use thiserror::Error; -use tokenizers::Tokenizer; -use tower_http::cors::AllowOrigin; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; +use text_generation_router::{internal_main, RouterError}; /// App Configuration #[derive(Parser, Debug)] @@ -86,487 +65,36 @@ struct Args { async fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); - // Pattern match configuration - let Args { - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - hostname, - port, - master_shard_uds_path, - tokenizer_name, - tokenizer_config_path, - revision, - validation_workers, - json_output, - otlp_endpoint, - cors_allow_origin, - ngrok, - ngrok_authtoken, - ngrok_edge, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, - } = args; - // Launch Tokio runtime - init_logging(otlp_endpoint, json_output); - - // Validate args - if max_input_tokens >= max_total_tokens { - return Err(RouterError::ArgumentValidation( - 
"`max_input_tokens` must be < `max_total_tokens`".to_string(), - )); - } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); - } - - if validation_workers == 0 { - return Err(RouterError::ArgumentValidation( - "`validation_workers` must be > 0".to_string(), - )); - } - - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - - // CORS allowed origins - // map to go inside the option and then map to parse from String to HeaderValue - // Finally, convert to AllowOrigin - let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { - AllowOrigin::list( - cors_allow_origin - .iter() - .map(|origin| origin.parse::().unwrap()), - ) - }); - - // Parse Huggingface hub token - let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); - - // Tokenizer instance - // This will only be used to validate payloads - let local_path = Path::new(&tokenizer_name); - - // Shared API builder initialization - let api_builder = || { - let mut builder = ApiBuilder::new() - .with_progress(false) - .with_token(authorization_token); - - if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { - builder = builder.with_cache_dir(cache_dir.into()); - } - - builder - }; - - // Decide if we need to use the API based on the revision and local path - let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); - - // Initialize API if needed - #[derive(Clone)] - enum Type { - Api(Api), - Cache(Cache), - None, - } - let api = if use_api { - if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = Cache::default(); - tracing::warn!("Offline mode active using cache defaults"); - Type::Cache(cache) - } else { - tracing::info!("Using the Hugging Face API"); - match api_builder().build() { - Ok(api) => Type::Api(api), - Err(_) => { - tracing::warn!("Unable to build the Hugging Face API"); - Type::None - } - } - } - } else { - Type::None - }; - - // Load tokenizer and model info - let ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - processor_config_filename, - model_info, - ) = match api { - Type::None => ( - Some(local_path.join("tokenizer.json")), - Some(local_path.join("config.json")), - Some(local_path.join("tokenizer_config.json")), - Some(local_path.join("processor_config.json")), - None, - ), - Type::Api(api) => { - let api_repo = api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - - let tokenizer_filename = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Some(tokenizer_filename), - Err(_) => get_base_tokenizer(&api, &api_repo).await, - }; - let config_filename = api_repo.get("config.json").await.ok(); - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); - let processor_config_filename = 
api_repo.get("processor_config.json").await.ok(); - - let model_info = if let Some(model_info) = get_model_info(&api_repo).await { - Some(model_info) - } else { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - None - }; - ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - processor_config_filename, - model_info, - ) - } - Type::Cache(cache) => { - let repo = cache.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - ( - repo.get("tokenizer.json"), - repo.get("config.json"), - repo.get("tokenizer_config.json"), - repo.get("processor_config.json"), - None, - ) - } - }; - let tokenizer: Option = - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); - let config: Option = config_filename.and_then(|filename| { - std::fs::read_to_string(filename) - .ok() - .as_ref() - .and_then(|c| { - let config: Result = serde_json::from_str(c); - if let Err(err) = &config { - tracing::warn!("Could not parse config {err:?}"); - } - config.ok() - }) - }); - let model_info = model_info.unwrap_or_else(|| HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - }); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. - let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path - { - HubTokenizerConfig::from_file(filename) - } else { - tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) - }; - let tokenizer_config = tokenizer_config.unwrap_or_else(|| { - tracing::warn!("Could not find tokenizer config locally and no API specified"); - HubTokenizerConfig::default() - }); - - let processor_config = processor_config_filename - .and_then(HubProcessorConfig::from_file) - .unwrap_or_default(); - - tracing::info!("Using config {config:?}"); - if tokenizer.is_none() { - tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); - tracing::warn!("Rust input length validation and truncation is disabled"); - } - - // if pipeline-tag == text-generation we default to return_full_text = true - let compat_return_full_text = match &model_info.pipeline_tag { - None => { - tracing::warn!("no pipeline tag found for model {tokenizer_name}"); - true - } - Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", - }; - - // Instantiate sharded client from the master unix socket - let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(RouterError::Connection)?; - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(RouterError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_supported_batch_total_tokens = match sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(RouterError::Warmup)? 
- { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); - tracing::warn!("Model does not support automatic max batch total tokens"); - max_batch_total_tokens - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); - } - - max_supported_batch_total_tokens - } - }; - tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); - tracing::info!("Connected"); - - // Determine the server port based on the feature and environment variable. - let port = if cfg!(feature = "google") { - std::env::var("AIP_HTTP_PORT") - .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) - .unwrap_or(port) - } else { - port - }; - - let addr = match hostname.parse() { - Ok(ip) => SocketAddr::new(ip, port), - Err(_) => { - tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) - } - }; - - // Run server - server::run( - model_info, - shard_info, - compat_return_full_text, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_supported_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - sharded_client, - tokenizer, - config, - validation_workers, - addr, - cors_allow_origin, - ngrok, - ngrok_authtoken, - ngrok_edge, - tokenizer_config, - processor_config, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, + internal_main( + args.max_concurrent_requests, + args.max_best_of, + args.max_stop_sequences, + args.max_top_n_tokens, + args.max_input_tokens, + args.max_total_tokens, + args.waiting_served_ratio, + args.max_batch_prefill_tokens, + args.max_batch_total_tokens, + args.max_waiting_tokens, + args.max_batch_size, + args.hostname, + args.port, + args.master_shard_uds_path, + args.tokenizer_name, + args.tokenizer_config_path, + args.revision, + args.validation_workers, + args.json_output, + args.otlp_endpoint, + args.cors_allow_origin, + args.ngrok, + args.ngrok_authtoken, + args.ngrok_edge, + args.messages_api_enabled, + args.disable_grammar_support, + args.max_client_batch_size, ) .await?; Ok(()) } - -/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: -/// - otlp_endpoint is an optional URL to an Open Telemetry collector -/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) -/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) -/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) -fn init_logging(otlp_endpoint: Option, json_output: bool) { - let mut layers = Vec::new(); - - // STDOUT/STDERR layer - let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); - let fmt_layer = tracing_subscriber::fmt::layer() 
- .with_file(true) - .with_ansi(ansi) - .with_line_number(true); - - let fmt_layer = match json_output { - true => fmt_layer.json().flatten_event(true).boxed(), - false => fmt_layer.boxed(), - }; - layers.push(fmt_layer); - - // OpenTelemetry tracing layer - if let Some(otlp_endpoint) = otlp_endpoint { - global::set_text_map_propagator(TraceContextPropagator::new()); - - let tracer = opentelemetry_otlp::new_pipeline() - .tracing() - .with_exporter( - opentelemetry_otlp::new_exporter() - .tonic() - .with_endpoint(otlp_endpoint), - ) - .with_trace_config( - trace::config() - .with_resource(Resource::new(vec![KeyValue::new( - "service.name", - "text-generation-inference.router", - )])) - .with_sampler(Sampler::AlwaysOn), - ) - .install_batch(opentelemetry::runtime::Tokio); - - if let Ok(tracer) = tracer { - layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); - init_tracing_opentelemetry::init_propagator().unwrap(); - }; - } - - // Filter events with LOG_LEVEL - let varname = "LOG_LEVEL"; - let env_filter = if let Ok(log_level) = std::env::var(varname) { - // Override to avoid simple logs to be spammed with tokio level informations - let log_level = match &log_level[..] { - "warn" => "text_generation_launcher=warn,text_generation_router=warn", - "info" => "text_generation_launcher=info,text_generation_router=info", - "debug" => "text_generation_launcher=debug,text_generation_router=debug", - log_level => log_level, - }; - EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .parse_lossy(log_level) - } else { - EnvFilter::new("info") - }; - - tracing_subscriber::registry() - .with(env_filter) - .with(layers) - .init(); -} - -/// get model info from the Huggingface Hub -pub async fn get_model_info(api: &ApiRepo) -> Option { - let response = api.info_request().send().await.ok()?; - - if response.status().is_success() { - let hub_model_info: HubModelInfo = - serde_json::from_str(&response.text().await.ok()?).ok()?; - if let Some(sha) = &hub_model_info.sha { - tracing::info!( - "Serving revision {sha} of model {}", - hub_model_info.model_id - ); - } - Some(hub_model_info) - } else { - None - } -} - -/// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { - let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. - let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - api_base_repo.get("tokenizer.json").await.ok() - } else { - None - } -} - -/// get tokenizer_config from the Huggingface Hub -pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(tokenizer_config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
- let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) - .map_err(|e| { - tracing::warn!("Unable to parse tokenizer config: {}", e); - e - }) - .ok()?; - - Some(tokenizer_config) -} - -#[derive(Debug, Error)] -enum RouterError { - #[error("Argument validation error: {0}")] - ArgumentValidation(String), - #[error("Unable to connect to the Python model shards: {0}")] - Connection(ClientError), - #[error("Unable to clear the Python model shards cache: {0}")] - Cache(ClientError), - #[error("Unable to get the Python model shards info: {0}")] - Info(ClientError), - #[error("Unable to warmup the Python model shards: {0}")] - Warmup(ClientError), - #[error("Tokio runtime failed to start: {0}")] - Tokio(#[from] std::io::Error), - #[error("Axum webserver failed: {0}")] - Axum(#[from] axum::BoxError), -} diff --git a/tgi/.gitignore b/tgi/.gitignore new file mode 100644 index 00000000..c8f04429 --- /dev/null +++ b/tgi/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/tgi/Cargo.toml b/tgi/Cargo.toml new file mode 100644 index 00000000..942fe892 --- /dev/null +++ b/tgi/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "tgi" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "tgi" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.20.0", features = ["extension-module"] } +pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] } +tokio = "1.4" +text-generation-router = { path = "../router" } +text-generation-launcher = { path = "../launcher" } diff --git a/tgi/Makefile b/tgi/Makefile new file mode 100644 index 00000000..a06df033 --- /dev/null +++ b/tgi/Makefile @@ -0,0 +1,6 @@ + +build: + maturin build + +install: build + pip install -e . diff --git a/tgi/README.md b/tgi/README.md new file mode 100644 index 00000000..d45cd9a1 --- /dev/null +++ b/tgi/README.md @@ -0,0 +1,47 @@ +# TGI (Python Package) + +> [!IMPORTANT] +> This is an experimental package and intended for research purposes only. The package is likely to change and should only be used for testing and development. + +`tgi` is a simple Python package that wraps the `text-generation-server` and `text-generation-launcher` packages. It provides a simple interface to the text generation server. + +```bash +make install +# this compiles the code and runs pip install for `tgi` +``` + +## Usage + +See the full example in the [`app.py`](./app.py) file. 
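+The wrapper starts the launcher and the router on a background thread; once the model is ready it listens on port 3000 (the wrapper's default), which is why the example below points `InferenceClient` at `http://localhost:3000`.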
+
+```python
+from tgi import TGI
+from huggingface_hub import InferenceClient
+import time
+
+llm = TGI(model_id="google/paligemma-3b-mix-224")
+
+# ✂️ startup logic snipped
+print("Model is ready!")
+
+client = InferenceClient("http://localhost:3000")
+generated = client.text_generation("What are the main characteristics of a cat?")
+print(generated)
+
+# Cats are known for their independent nature, curious minds, and affectionate nature. Here are the main characteristics of a cat...
+
+llm.close()
+```
+
+## How it works
+
+Technically this is a `pyo3` package that wraps the `text-generation-server` and `text-generation-launcher` packages, and slightly modifies the launcher to rely on the internal code rather than launching an external binary.
+
+## Known issues/limitations
+
+- [ ] server does not gracefully handle shutdowns (trying to avoid python context for better notebook dev experience)
+- [ ] issues with tracing (launcher and router should share tracer)
+- [ ] text-generation-server is not integrated and still relies on the external install
+- [ ] not all parameters are exposed/passed through
+- [ ] general cleanup and refactoring needed
+- [ ] review naming and explore pushing to PyPi
diff --git a/tgi/app.py b/tgi/app.py
new file mode 100644
index 00000000..5e7c0d6e
--- /dev/null
+++ b/tgi/app.py
@@ -0,0 +1,38 @@
+from tgi import TGI
+from huggingface_hub import InferenceClient
+import time
+
+llm = TGI(model_id="google/paligemma-3b-mix-224")
+client = InferenceClient("http://localhost:3000")
+
+while True:
+    print("Waiting for the model to be ready...")
+    try:
+        time.sleep(5)
+        generated = client.text_generation("What is Deep Learning?")
+        break
+    except Exception as e:
+        print(e)
+
+print("Model is ready!")
+
+time.sleep(2)
+
+# do a couple of inference requests
+print("Generating text...")
+generated = client.text_generation("Where is the capital of France?")
+print(generated)
+
+time.sleep(2)
+
+generated = client.text_generation(
+    "What can you tell me about the history of the United States?"
+) +print(generated) + +time.sleep(2) + +generated = client.text_generation("What are the main characteristics of a cat?") +print(generated) + +llm.close() diff --git a/tgi/pyproject.toml b/tgi/pyproject.toml new file mode 100644 index 00000000..cb103d0b --- /dev/null +++ b/tgi/pyproject.toml @@ -0,0 +1,15 @@ +[build-system] +requires = ["maturin>=1.5,<2.0"] +build-backend = "maturin" + +[project] +name = "tgi" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/tgi/src/lib.rs b/tgi/src/lib.rs new file mode 100644 index 00000000..793dc64e --- /dev/null +++ b/tgi/src/lib.rs @@ -0,0 +1,455 @@ +use pyo3::{prelude::*, wrap_pyfunction}; +use text_generation_launcher::{launcher_main, launcher_main_without_server}; +use text_generation_router::internal_main; + +#[allow(clippy::too_many_arguments)] +#[pyfunction] +#[pyo3(signature = ( + model_id, + revision, + validation_workers, + sharded, + num_shard, + _quantize, + speculate, + _dtype, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + _rope_scaling, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, +))] +fn rust_launcher( + py: Python<'_>, + model_id: String, + revision: Option, + validation_workers: usize, + sharded: Option, + num_shard: Option, + _quantize: Option, // Option, + speculate: Option, + _dtype: Option, // Option, + trust_remote_code: bool, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: Option, + max_input_length: Option, + max_total_tokens: Option, + waiting_served_ratio: f32, + max_batch_prefill_tokens: Option, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + cuda_graphs: Option>, + hostname: String, + port: u16, + shard_uds_path: String, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + cuda_memory_fraction: f32, + _rope_scaling: Option, // Option, + rope_factor: Option, + json_output: bool, + otlp_endpoint: Option, + cors_allow_origin: Vec, + watermark_gamma: Option, + watermark_delta: Option, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_edge: Option, + tokenizer_config_path: Option, + disable_grammar_support: bool, + env: bool, + max_client_batch_size: usize, +) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + launcher_main( + model_id, + revision, + validation_workers, + sharded, + num_shard, + None, + speculate, + None, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + 
max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + None, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, + ) + .unwrap(); + + Ok(Python::with_gil(|py| py.None())) + }) +} + +#[allow(clippy::too_many_arguments)] +#[pyfunction] +#[pyo3(signature = ( + model_id, + revision, + validation_workers, + sharded, + num_shard, + _quantize, + speculate, + _dtype, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + _rope_scaling, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, +))] +fn fully_packaged( + py: Python<'_>, + model_id: String, + revision: Option, + validation_workers: usize, + sharded: Option, + num_shard: Option, + _quantize: Option, // Option, + speculate: Option, + _dtype: Option, // Option, + trust_remote_code: bool, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: Option, + max_input_length: Option, + max_total_tokens: Option, + waiting_served_ratio: f32, + max_batch_prefill_tokens: Option, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + cuda_graphs: Option>, + hostname: String, + port: u16, + shard_uds_path: String, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + cuda_memory_fraction: f32, + _rope_scaling: Option, // Option, + rope_factor: Option, + json_output: bool, + otlp_endpoint: Option, + cors_allow_origin: Vec, + watermark_gamma: Option, + watermark_delta: Option, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_edge: Option, + tokenizer_config_path: Option, + disable_grammar_support: bool, + env: bool, + max_client_batch_size: usize, +) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + use std::thread; + use tokio::runtime::Runtime; + + let model_id_clone = model_id.clone(); + let max_concurrent_requests_clone = max_concurrent_requests; + let max_best_of_clone = max_best_of; + let max_stop_sequences_clone = max_stop_sequences; + let max_top_n_tokens_clone = max_top_n_tokens; + let max_input_tokens_clone = max_input_tokens.unwrap_or(1024); + let max_total_tokens_clone = max_total_tokens.unwrap_or(2048); + let waiting_served_ratio_clone = waiting_served_ratio; + + let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096); + let max_batch_total_tokens_clone = max_batch_total_tokens; + let max_waiting_tokens_clone = max_waiting_tokens; + let max_batch_size_clone = max_batch_size; + let hostname_clone = hostname.clone(); + let port_clone = port; + + 
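+        // The *_clone bindings duplicate these arguments so one copy can be moved into the
+        // webserver callback/thread below while the originals are still passed to
+        // launcher_main_without_server afterwards.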
// TODO: fix this + let _shard_uds_path_clone = shard_uds_path.clone(); + + let tokenizer_config_path = tokenizer_config_path.clone(); + let revision = revision.clone(); + let validation_workers = validation_workers; + let json_output = json_output; + + let otlp_endpoint = otlp_endpoint.clone(); + let cors_allow_origin = cors_allow_origin.clone(); + let ngrok = ngrok; + let ngrok_authtoken = ngrok_authtoken.clone(); + let ngrok_edge = ngrok_edge.clone(); + let messages_api_enabled = true; + let disable_grammar_support = disable_grammar_support; + let max_client_batch_size = max_client_batch_size; + + let ngrok_clone = ngrok; + let ngrok_authtoken_clone = ngrok_authtoken.clone(); + let ngrok_edge_clone = ngrok_edge.clone(); + let messages_api_enabled_clone = messages_api_enabled; + let disable_grammar_support_clone = disable_grammar_support; + let max_client_batch_size_clone = max_client_batch_size; + + let tokenizer_config_path_clone = tokenizer_config_path.clone(); + let revision_clone = revision.clone(); + let validation_workers_clone = validation_workers; + let json_output_clone = json_output; + let otlp_endpoint_clone = otlp_endpoint.clone(); + + let webserver_callback = move || { + let handle = thread::spawn(move || { + let rt = Runtime::new().unwrap(); + rt.block_on(async { + internal_main( + max_concurrent_requests_clone, + max_best_of_clone, + max_stop_sequences_clone, + max_top_n_tokens_clone, + max_input_tokens_clone, + max_total_tokens_clone, + waiting_served_ratio_clone, + max_batch_prefill_tokens_clone, + max_batch_total_tokens_clone, + max_waiting_tokens_clone, + max_batch_size_clone, + hostname_clone, + port_clone, + "/tmp/text-generation-server-0".to_string(), + model_id_clone, + tokenizer_config_path_clone, + revision_clone, + validation_workers_clone, + json_output_clone, + otlp_endpoint_clone, + None, + ngrok_clone, + ngrok_authtoken_clone, + ngrok_edge_clone, + messages_api_enabled_clone, + disable_grammar_support_clone, + max_client_batch_size_clone, + ) + .await + }) + }); + match handle.join() { + Ok(_) => println!("Server exited successfully"), + Err(e) => println!("Server exited with error: {:?}", e), + } + Ok(()) + }; + + // parse the arguments and run the main function + launcher_main_without_server( + model_id, + revision, + validation_workers, + sharded, + num_shard, + None, + speculate, + None, + trust_remote_code, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + cuda_graphs, + hostname, + port, + shard_uds_path, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + cuda_memory_fraction, + None, + rope_factor, + json_output, + otlp_endpoint, + cors_allow_origin, + watermark_gamma, + watermark_delta, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config_path, + disable_grammar_support, + env, + max_client_batch_size, + Box::new(webserver_callback), + ) + .unwrap(); + + Ok(Python::with_gil(|py| py.None())) + }) +} + +/// Asynchronous sleep function. 
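+/// Sleeps for 20 seconds on the pyo3-asyncio Tokio runtime and resolves to `None` on the Python side.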
+#[pyfunction] +fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async { + tokio::time::sleep(std::time::Duration::from_secs(20)).await; + Ok(Python::with_gil(|py| py.None())) + }) +} + +// TODO: remove hardcoding +#[pyfunction] +fn rust_server(py: Python<'_>) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async { + let _ = internal_main( + 128, // max_concurrent_requests: usize, + 2, // max_best_of: usize, + 4, // max_stop_sequences: usize, + 5, // max_top_n_tokens: u32, + 1024, // max_input_tokens: usize, + 2048, // max_total_tokens: usize, + 1.2, // waiting_served_ratio: f32, + 4096, // max_batch_prefill_tokens: u32, + None, // max_batch_total_tokens: Option, + 20, // max_waiting_tokens: usize, + None, // max_batch_size: Option, + "0.0.0.0".to_string(), // hostname: String, + 3000, // port: u16, + "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String, + "llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String, + None, // tokenizer_config_path: Option, + None, // revision: Option, + 2, // validation_workers: usize, + false, // json_output: bool, + None, // otlp_endpoint: Option, + None, // cors_allow_origin: Option>, + false, // ngrok: bool, + None, // ngrok_authtoken: Option, + None, // ngrok_edge: Option, + false, // messages_api_enabled: bool, + false, // disable_grammar_support: bool, + 4, // max_client_batch_size: usize, + ) + .await; + Ok(Python::with_gil(|py| py.None())) + }) +} + +#[pymodule] +fn tgi(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(rust_sleep, m)?)?; + m.add_function(wrap_pyfunction!(rust_server, m)?)?; + m.add_function(wrap_pyfunction!(rust_launcher, m)?)?; + m.add_function(wrap_pyfunction!(fully_packaged, m)?)?; + Ok(()) +} diff --git a/tgi/tgi/__init__.py b/tgi/tgi/__init__.py new file mode 100644 index 00000000..2826915e --- /dev/null +++ b/tgi/tgi/__init__.py @@ -0,0 +1,132 @@ +from .tgi import * +import threading +from tgi import rust_launcher, rust_sleep, fully_packaged +import asyncio +from dataclasses import dataclass, asdict +import sys + +# add the rust_launcher coroutine to the __all__ list +__doc__ = tgi.__doc__ +if hasattr(tgi, "__all__"): + __all__ = tgi.__all__ + + +@dataclass +class Args: + model_id = "google/gemma-2b-it" + revision = None + validation_workers = 2 + sharded = None + num_shard = None + quantize = None + speculate = None + dtype = None + trust_remote_code = True + max_concurrent_requests = 128 + max_best_of = 2 + max_stop_sequences = 4 + max_top_n_tokens = 5 + max_input_tokens = None + max_input_length = None + max_total_tokens = None + waiting_served_ratio = 0.3 + max_batch_prefill_tokens = None + max_batch_total_tokens = None + max_waiting_tokens = 20 + max_batch_size = None + cuda_graphs = None + hostname = "0.0.0.0" + port = 3000 + shard_uds_path = "/tmp/text-generation-server" + master_addr = "localhost" + master_port = 29500 + huggingface_hub_cache = None + weights_cache_override = None + disable_custom_kernels = False + cuda_memory_fraction = 1.0 + rope_scaling = None + rope_factor = None + json_output = False + otlp_endpoint = None + cors_allow_origin = [] + watermark_gamma = None + watermark_delta = None + ngrok = False + ngrok_authtoken = None + ngrok_edge = None + tokenizer_config_path = None + disable_grammar_support = False + env = False + max_client_batch_size = 4 + + +class TGI(object): + # only allow a limited set of arguments for now + def __init__(self, model_id=None): + app_args = 
Args() + if model_id: + app_args.model_id = model_id + + print(asdict(app_args)) + self.thread = threading.Thread(target=self.run, args=(asdict(app_args),)) + self.thread.start() + + async def runit(self, args: dict): + print(args) + args = Args(**args) + try: + await fully_packaged( + args.model_id, + args.revision, + args.validation_workers, + args.sharded, + args.num_shard, + args.quantize, + args.speculate, + args.dtype, + args.trust_remote_code, + args.max_concurrent_requests, + args.max_best_of, + args.max_stop_sequences, + args.max_top_n_tokens, + args.max_input_tokens, + args.max_input_length, + args.max_total_tokens, + args.waiting_served_ratio, + args.max_batch_prefill_tokens, + args.max_batch_total_tokens, + args.max_waiting_tokens, + args.max_batch_size, + args.cuda_graphs, + args.hostname, + args.port, + args.shard_uds_path, + args.master_addr, + args.master_port, + args.huggingface_hub_cache, + args.weights_cache_override, + args.disable_custom_kernels, + args.cuda_memory_fraction, + args.rope_scaling, + args.rope_factor, + args.json_output, + args.otlp_endpoint, + args.cors_allow_origin, + args.watermark_gamma, + args.watermark_delta, + args.ngrok, + args.ngrok_authtoken, + args.ngrok_edge, + args.tokenizer_config_path, + args.disable_grammar_support, + args.env, + args.max_client_batch_size, + ) + except Exception as e: + print(e) + + def run(self, args: dict): + asyncio.run(self.runit(args)) + + def close(self): + self.thread.join()
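A minimal sketch for exercising the compiled extension directly, assuming the package has been built and installed (e.g. `make install` in `tgi/`): the exported `rust_sleep` coroutine simply bridges a Tokio sleep into asyncio, which makes it a convenient smoke test of the binding layer before starting a full model.

```python
# Minimal smoke test of the pyo3/pyo3-asyncio bridge exposed by the `tgi` extension.
# Assumes the package was built and installed (e.g. `make install` in tgi/).
import asyncio

from tgi import rust_sleep


async def main() -> None:
    # `rust_sleep()` returns an awaitable backed by a Tokio future;
    # it sleeps for ~20 seconds and resolves to None.
    await rust_sleep()


asyncio.run(main())
```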