diff --git a/Cargo.lock b/Cargo.lock index 22eec927..539cf124 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2893,7 +2893,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "0.9.2" +version = "0.9.3" dependencies = [ "average", "clap", @@ -2913,7 +2913,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.9.2" +version = "0.9.3" dependencies = [ "futures", "grpc-metadata", @@ -2929,7 +2929,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.9.2" +version = "0.9.3" dependencies = [ "clap", "ctrlc", @@ -2945,7 +2945,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.9.2" +version = "0.9.3" dependencies = [ "async-stream", "axum", diff --git a/Cargo.toml b/Cargo.toml index 1383b7f9..49b7717a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.9.2" +version = "0.9.3" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/Dockerfile b/Dockerfile index 66e0091d..168f2f97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile # Build specific version of flash attention RUN make build-flash-attention +# Build Flash Attention v2 CUDA kernels +FROM kernel-builder as flash-att-v2-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2 + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder @@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages 
+# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages diff --git a/README.md b/README.md index fe55d7b5..43388d00 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,8 @@ to power LLMs api-inference widgets. - [Starcoder](https://huggingface.co/bigcode/starcoder) - [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b) - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b) +- [MPT](https://huggingface.co/mosaicml/mpt-30b) +- [Llama V2](https://huggingface.co/meta-llama) Other architectures are supported on a best effort basis using: @@ -132,6 +134,10 @@ print(text) You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference). +### Using on private models or gated models + +You can use the `HUGGING_FACE_HUB_TOKEN` environment variable to set the token that `text-generation-inference` uses to access protected resources. + ### Distributed Tracing `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature @@ -211,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y ### CUDA Kernels The custom CUDA kernels are only tested on NVIDIA A100s. 
If you have any installation or runtime issues, you can remove -the kernels by using the `BUILD_EXTENSIONS=False` environment variable. +the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. Be aware that the official Docker image has them enabled by default. diff --git a/docs/openapi.json b/docs/openapi.json index e570d92d..80240460 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.9.2" + "version": "0.9.3" }, "paths": { "/": { diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index ae0694da..3e7f86d4 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -13,7 +13,7 @@ nix = "0.26.2" serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.93" tracing = "0.1.37" -tracing-subscriber = { version = "0.3.16", features = ["json"] } +tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } [dev-dependencies] float_eq = "1.0.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d690a7c4..53de36b2 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -4,7 +4,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Read}; +use std::io::{BufRead, BufReader, Lines, Read}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -15,6 +15,7 @@ use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; +use tracing_subscriber::EnvFilter; mod env_runtime; @@ -41,6 +42,7 @@ impl std::fmt::Display for Quantization { #[derive(Clone, Copy, Debug, ValueEnum)] enum Dtype { Float16, + #[clap(name = "bfloat16")] BFloat16, } @@ -182,8 +184,8 @@ struct Args { /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference 
cannot infer this number /// automatically. - #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, /// This setting defines how many tokens can be passed before forcing the waiting /// queries to be put on the batch (if the size of the batch allows for it). @@ -265,17 +267,9 @@ struct Args { #[clap(long, env)] ngrok_authtoken: Option, - /// ngrok domain name where the axum webserver will be available at + /// ngrok edge #[clap(long, env)] - ngrok_domain: Option, - - /// ngrok basic auth username - #[clap(long, env)] - ngrok_username: Option, - - /// ngrok basic auth password - #[clap(long, env)] - ngrok_password: Option, + ngrok_edge: Option, /// Display a lot of information about your runtime environment #[clap(long, short, action)] @@ -285,7 +279,7 @@ struct Args { #[derive(Debug)] enum ShardStatus { Ready, - Failed((usize, Option)), + Failed(usize), } #[allow(clippy::too_many_arguments)] @@ -310,6 +304,9 @@ fn shard_manager( shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, ) { + // Enter shard-manager tracing span + let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); + // Get UDS path let uds_string = format!("{uds_path}-{rank}"); let uds = Path::new(&uds_string); @@ -364,12 +361,6 @@ fn shard_manager( // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - // Use cuda allocator. 
It leads to less memory fragmentation - envs.push(( - "PYTORCH_CUDA_ALLOC_CONF".into(), - "backend:cudaMallocAsync".into(), - )); - // Torch Distributed Env vars envs.push(("RANK".into(), rank.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -423,7 +414,7 @@ fn shard_manager( } // Start process - tracing::info!("Starting shard {rank}"); + tracing::info!("Starting shard"); let mut p = match Command::new("text-generation-server") .args(shard_args) .envs(envs) @@ -437,30 +428,23 @@ fn shard_manager( if err.kind() == io::ErrorKind::NotFound { tracing::error!("text-generation-server not found in PATH"); tracing::error!("Please install it with `make install-server`") - } else { + } + { tracing::error!("{}", err); } - status_sender - .send(ShardStatus::Failed((rank, Some(err.to_string())))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } }; // Redirect STDOUT to the console let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); - let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + //stdout tracing thread thread::spawn(move || { - // Enter shard-manager tracing span - let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - for line in shard_stdout_reader.lines() { - // Parse loguru logs - if let Ok(log) = serde_json::from_str::(&line.unwrap()) { - log.trace(); - } - } + log_lines(shard_stdout_reader.lines()); }); let mut ready = false; @@ -469,30 +453,25 @@ fn shard_manager( loop { // Process exited if let Some(exit_status) = p.try_wait().unwrap() { - // We read stderr in another thread as it seems that `read_to_string` can block - // indefinitely in some cases + // We read stderr in another thread as it seems that lines() can block in some cases let (err_sender, err_receiver) = mpsc::channel(); thread::spawn(move || { - let mut err = String::new(); - 
shard_stderr_reader.read_to_string(&mut err).unwrap(); - err_sender.send(err).unwrap_or(()); + for line in shard_stderr_reader.lines().flatten() { + err_sender.send(line).unwrap_or(()); + } }); + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } - let err = err_receiver - .recv_timeout(Duration::from_millis(100)) - .map_err(|err| { - tracing::error!("Unable to read shard {rank} error from stderr"); - err - }) - .ok(); + tracing::error!("Shard complete standard error output:\n{err}"); if let Some(signal) = exit_status.signal() { tracing::error!("Shard process was signaled to shutdown with signal {signal}"); } - status_sender - .send(ShardStatus::Failed((rank, err))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } @@ -500,17 +479,17 @@ fn shard_manager( if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); let _ = p.wait(); - tracing::info!("Shard {rank} terminated"); + tracing::info!("Shard terminated"); return; } // Shard is ready if uds.exists() && !ready { - tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); + tracing::info!("Shard ready in {:?}", start_time.elapsed()); status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {rank} to be ready..."); + tracing::info!("Waiting for shard to be ready..."); wait_time = Instant::now(); } sleep(Duration::from_millis(100)); @@ -579,6 +558,23 @@ impl PythonLogMessage { } } +impl TryFrom<&String> for PythonLogMessage { + type Error = serde_json::Error; + + fn try_from(value: &String) -> Result { + serde_json::from_str::(value) + } +} + +fn log_lines(lines: Lines) { + for line in lines.flatten() { + match PythonLogMessage::try_from(&line) { + Ok(log) => log.trace(), + Err(_) => tracing::debug!("{line}"), + } + } +} + fn find_num_shards( sharded: Option, num_shard: Option, @@ 
-632,6 +628,9 @@ enum LauncherError { } fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { + // Enter download tracing span + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + let mut download_args = vec![ "download-weights".to_string(), args.model_id.to_string(), @@ -693,6 +692,8 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L if err.kind() == io::ErrorKind::NotFound { tracing::error!("text-generation-server not found in PATH"); tracing::error!("Please install it with `make install-server`") + } else { + tracing::error!("{}", err); } return Err(LauncherError::DownloadError); @@ -701,16 +702,10 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Redirect STDOUT to the console let download_stdout = download_process.stdout.take().unwrap(); + let stdout = BufReader::new(download_stdout); + thread::spawn(move || { - // Enter download tracing span - let stdout = BufReader::new(download_stdout); - let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - for line in stdout.lines() { - // Parse loguru logs - if let Ok(log) = serde_json::from_str::(&line.unwrap()) { - log.trace(); - } - } + log_lines(stdout.lines()); }); loop { @@ -815,11 +810,8 @@ fn spawn_shards( Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100)); } - Ok(ShardStatus::Failed((rank, err))) => { + Ok(ShardStatus::Failed(rank)) => { tracing::error!("Shard {rank} failed to start"); - if let Some(err) = err { - tracing::error!("{err}"); - } shutdown_shards(shutdown, shutdown_receiver); return Err(LauncherError::ShardCannotStart); } @@ -854,8 +846,6 @@ fn spawn_webserver( args.max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), args.max_batch_prefill_tokens.to_string(), - "--max-batch-total-tokens".to_string(), - args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), 
"--max-waiting-tokens".to_string(), @@ -872,6 +862,12 @@ fn spawn_webserver( args.model_id, ]; + // Model optional max batch total tokens + if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { + router_args.push("--max-batch-total-tokens".to_string()); + router_args.push(max_batch_total_tokens.to_string()); + } + // Model optional revision if let Some(ref revision) = args.revision { router_args.push("--revision".to_string()); @@ -896,26 +892,11 @@ fn spawn_webserver( // Ngrok if args.ngrok { - let authtoken = args.ngrok_authtoken.ok_or_else(|| { - tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling"); - LauncherError::WebserverCannotStart - })?; - router_args.push("--ngrok".to_string()); router_args.push("--ngrok-authtoken".to_string()); - router_args.push(authtoken); - - if let Some(domain) = args.ngrok_domain { - router_args.push("--ngrok-domain".to_string()); - router_args.push(domain); - } - - if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) { - router_args.push("--ngrok-username".to_string()); - router_args.push(username); - router_args.push("--ngrok-password".to_string()); - router_args.push(password); - } + router_args.push(args.ngrok_authtoken.unwrap()); + router_args.push("--ngrok-edge".to_string()); + router_args.push(args.ngrok_edge.unwrap()); } // Copy current process env @@ -993,12 +974,22 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R fn main() -> Result<(), LauncherError> { // Pattern match configuration - let args = Args::parse(); + let args: Args = Args::parse(); + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); if args.json_output { - tracing_subscriber::fmt().json().init(); + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .json() + .init(); } else { - tracing_subscriber::fmt().compact().init(); + tracing_subscriber::fmt() + 
.with_env_filter(env_filter) + .compact() + .init(); } if args.env { @@ -1020,18 +1011,7 @@ fn main() -> Result<(), LauncherError> { args.max_batch_prefill_tokens, args.max_input_length ))); } - if args.max_batch_prefill_tokens > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, args.max_batch_total_tokens - ))); - } - if args.max_total_tokens as u32 > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_total_tokens, args.max_batch_total_tokens - ))); - } + if args.validation_workers == 0 { return Err(LauncherError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), @@ -1049,6 +1029,35 @@ fn main() -> Result<(), LauncherError> { tracing::info!("Sharding model on {num_shard} processes"); } + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if args.max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if args.max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", + args.max_total_tokens, max_batch_total_tokens + ))); + } + } + + if args.ngrok { + if args.ngrok_authtoken.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), + )); + } + + if args.ngrok_edge.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-edge` must be set when using ngrok tunneling".to_string(), + )); + } + } + // Signal handler let running = Arc::new(AtomicBool::new(true)); let r = running.clone(); @@ -1101,11 +1110,8 @@ fn main() -> Result<(), LauncherError> { let mut exit_code = Ok(()); while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { + if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { tracing::error!("Shard {rank} crashed"); - if let Some(err) = err { - tracing::error!("{err}"); - } exit_code = Err(LauncherError::ShardFailed); break; }; diff --git a/proto/generate.proto b/proto/generate.proto index 5e061941..57d79bca 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -198,9 +198,10 @@ message DecodeResponse { message WarmupRequest { /// Batch to warmup on Batch batch = 1; - /// Maximum number of tokens that the client will send - uint32 max_total_tokens = 2; } /// Empty response -message WarmupResponse {} +message WarmupResponse { + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index b9607a5d..7753f307 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -103,8 +103,7 @@ impl Client { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let mut n_tokens = 0; let mut requests = Vec::new(); @@ -143,13 +142,9 @@ impl Client { max_tokens: 0, }; - let request = tonic::Request::new(WarmupRequest { - batch: Some(batch), - 
max_total_tokens, - }) - .inject_context(); - self.stub.warmup(request).await?.into_inner(); - Ok(()) + let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) } /// Generate one token for each request in the given batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 9dd173a0..6d146bc5 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -95,14 +95,11 @@ impl ShardedClient { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| { - Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) - }) + .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() diff --git a/router/src/infer.rs b/router/src/infer.rs index d0d22d3b..188ddc64 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -53,7 +53,7 @@ impl Infer { generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding); + let queue = Queue::new(requires_padding, 16); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/main.rs b/router/src/main.rs index 178c249c..059f8692 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -37,8 +37,8 @@ struct Args { waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "0.0.0.0", long, env)] @@ -64,11 +64,7 @@ struct Args { #[clap(long, env)] 
ngrok_authtoken: Option, #[clap(long, env)] - ngrok_domain: Option, - #[clap(long, env)] - ngrok_username: Option, - #[clap(long, env)] - ngrok_password: Option, + ngrok_edge: Option, } fn main() -> Result<(), RouterError> { @@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> { cors_allow_origin, ngrok, ngrok_authtoken, - ngrok_domain, - ngrok_username, - ngrok_password, + ngrok_edge, } = args; // Validate args @@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> { if max_input_length as u32 > max_batch_prefill_tokens { return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); } - if max_batch_prefill_tokens > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } + if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> { // Warmup model tracing::info!("Warming up model"); - sharded_client - .warmup( - max_input_length as u32, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) + let max_supported_batch_total_tokens = match sharded_client + .warmup(max_input_length as u32, max_batch_prefill_tokens) .await - .map_err(RouterError::Warmup)?; + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), + ); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." 
+ ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); let addr = match hostname.parse() { @@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> { max_total_tokens, waiting_served_ratio, max_batch_prefill_tokens, - max_batch_total_tokens, + max_supported_batch_total_tokens, max_waiting_tokens, sharded_client, tokenizer, @@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> { cors_allow_origin, ngrok, ngrok_authtoken, - ngrok_domain, - ngrok_username, - ngrok_password, + ngrok_edge, ) .await?; Ok(()) diff --git a/router/src/queue.rs b/router/src/queue.rs index 48e483a1..2d8d6d1c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -33,12 +33,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool) -> Self { + pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(requires_padding, queue_receiver)); + tokio::spawn(queue_task(requires_padding, block_size, queue_receiver)); Self { queue_sender } } @@ -81,8 +81,12 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(requires_padding: bool, receiver: flume::Receiver) { - let mut state = State::new(requires_padding); +async fn queue_task( + requires_padding: bool, + block_size: u32, + receiver: flume::Receiver, +) { + let mut state = State::new(requires_padding, block_size); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -119,15 +123,19 @@ struct State { /// Whether the model is using padding requires_padding: bool, + + /// Paged Attention block size + block_size: u32, } impl State { - fn new(requires_padding: bool) -> Self { + fn new(requires_padding: bool, 
block_size: u32) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, + block_size, } } @@ -187,10 +195,21 @@ impl State { max_input_length = max_input_length.max(entry.request.input_length); prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { - prefill_tokens += entry.request.input_length; + // pad to block size + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) + * self.block_size; } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + // pad to block size + decode_tokens += + ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1) + / self.block_size) + * self.block_size; + } if prefill_tokens > prefill_token_budget || (prefill_tokens + decode_tokens) > token_budget @@ -321,7 +340,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -337,7 +356,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false); + let mut state = State::new(false, 1); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -345,7 +364,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -377,7 +396,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -410,14 +429,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = 
Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -425,7 +444,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -458,7 +477,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -483,7 +502,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index 8ca463c2..bfeee375 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -524,9 +524,7 @@ pub async fn run( allow_origin: Option, ngrok: bool, ngrok_authtoken: Option, - ngrok_domain: Option, - ngrok_username: Option, - ngrok_password: Option, + ngrok_edge: Option, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -696,32 +694,25 @@ pub async fn run( #[cfg(feature = "ngrok")] { use ngrok::config::TunnelBuilder; - use ngrok::tunnel::UrlTunnel; let _ = addr; let authtoken = ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - let mut tunnel = ngrok::Session::builder() + let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); + + let tunnel = ngrok::Session::builder() .authtoken(authtoken) .connect() .await .unwrap() 
- .http_endpoint(); - - if let Some(domain) = ngrok_domain { - tunnel = tunnel.domain(domain); - } - - if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) { - tunnel = tunnel.basic_auth(username, password); - } + .labeled_tunnel() + .label("edge", edge); let listener = tunnel.listen().await.unwrap(); // Run server - tracing::info!("Ingress URL: {:?}", listener.url()); axum::Server::builder(listener) .serve(app.into_make_service()) //Wait until all requests are finished to shut down diff --git a/server/Makefile b/server/Makefile index d0086928..0dc0b5c9 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ include Makefile-flash-att +include Makefile-flash-att-v2 include Makefile-vllm unit-tests: diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 new file mode 100644 index 00000000..a7d63356 --- /dev/null +++ b/server/Makefile-flash-att-v2 @@ -0,0 +1,13 @@ +flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc + +flash-attention-v2: + # Clone flash attention + pip install packaging + git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + +build-flash-attention-v2: flash-attention-v2 + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) + cd flash-attention-v2 && python setup.py build + +install-flash-attention-v2: build-flash-attention-v2 + cd flash-attention-v2 && python setup.py install \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index 57c72371..be79da51 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.9.2" +version = "0.9.3" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 7a55e919..e74c0331 100644 --- a/server/text_generation_server/cli.py +++ 
b/server/text_generation_server/cli.py @@ -194,6 +194,8 @@ def quantize( percdamp: float = 0.01, act_order: bool = False, ): + if revision is None: + revision = "main" download_weights( model_id=model_id, revision=revision, @@ -207,6 +209,7 @@ def quantize( bits=4, groupsize=128, output_dir=output_dir, + revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fd97f8b1..ffc224cc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -42,51 +42,21 @@ __all__ = [ "get_model", ] -FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA and Flash Attention kernels to be installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" -) +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." +FLASH_ATTENTION = True try: - if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - if not torch.cuda.is_available(): - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA. No compatible CUDA devices found." - ) - raise ImportError("CUDA is not available") - - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - - supported = is_sm75 or is_sm8x or is_sm90 - if not supported: - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. " - "No compatible CUDA device found." 
- ) - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) - - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, - ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, - ) - - FLASH_ATTENTION = True - else: - FLASH_ATTENTION = False -except ImportError: - logger.opt(exception=True).warning( - "Could not import Flash Attention enabled models" + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_neox import FlashNeoXSharded + from text_generation_server.models.flash_llama import ( + FlashLlama, ) + from text_generation_server.models.flash_santacoder import ( + FlashSantacoderSharded, + ) + +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") FLASH_ATTENTION = False if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d9f3c7b8..b2bde282 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -23,25 +23,77 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple # Flash attention imports -import flash_attn_cuda import dropout_layer_norm # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, TensorParallelHead, + get_linear, ) 
+class LlamaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + class LlamaRMSNorm(nn.Module): def __init__(self, prefix, weights, eps=1e-6): """ @@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module): hidden_states += residual residual = hidden_states - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) @@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module): return normed_hidden_states, res +def _load_gqa(config, prefix: str, weights): + w = [ + weights.get_sharded(f"{prefix}.q_proj.weight", dim=0), + weights.get_sharded(f"{prefix}.k_proj.weight", dim=0), + 
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0), + ] + weight = torch.cat(w, dim=0) + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + bias = None + assert config.hidden_size % config.num_attention_heads == 0 + head_size = config.hidden_size // config.num_attention_heads + assert config.num_attention_heads % weights.process_group.size() == 0 + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -118,22 +192,29 @@ class FlashLlamaAttention(torch.nn.Module): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.query_key_value = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() ) + if config.num_attention_heads != config.num_key_value_heads: + self.query_key_value = _load_gqa(config, prefix, weights) + else: + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( - 0, self.num_heads, dtype=torch.int32, device=weights.device - ) + 0, self.num_key_value_heads, 
dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module): max_s, ): qkv = self.query_key_value(hidden_states) - qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - # Inplace rotary - self.rotary_emb(qkv[:, 0], cos, sin) - self.rotary_emb(qkv[:, 1], cos, sin) + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) vllm_cache_ops.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) # output tensor - attn_output = torch.empty_like(qkv[:, 0]) + attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: @@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module): block_size = kv_cache[1].shape[3] vllm_attention_ops.single_query_cached_kv_attention( attn_output, - qkv[:, 0], + query, kv_cache[0], kv_cache[1], self.kv_head_mapping, @@ -324,6 +404,7 @@ class FlashLlamaModel(torch.nn.Module): self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 
b2dce226..e7c8ced4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index acac2744..1e9539c4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - if self.num_heads_kv == 1: - # Expand to query shape - kv = kv.expand(-1, 2, 
self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: @@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # Expand to query shape - kv = ( - kv.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a19623a5..6f5c60fc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,13 +5,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # Expand from 1 to num_heads - key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(key_value, dim=1, 
index=0), torch.select(key_value, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d034d472..517fba68 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -710,14 +710,14 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch - def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + def warmup(self, batch: FlashCausalLMBatch): global CACHE_MANAGER torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(self.device) try: CACHE_MANAGER = CacheManager( - # Adds some wiggle room - math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, @@ -727,11 +727,43 @@ class FlashCausalLM(Model): _, batch = self.generate_token(batch) except Exception as e: raise RuntimeError( - f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " - f"prefill tokens. " - f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" ) from e + + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.cuda.synchronize(self.device) + peak_memory = torch.cuda.max_memory_reserved(self.device) + + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory + + # 0.98 to add some wiggle room + num_blocks = ( + int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + # Add batch.blocks as we allocated it above, so it is included in the peak memory. + + batch.blocks + ) + + del CACHE_MANAGER del batch + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( @@ -991,7 +1023,6 @@ class FlashCausalLM(Model): if stopped: del batch - torch.cuda.empty_cache() # No need to return a batch if we know that all requests stopped return generations, None diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 11c77e14..29a24816 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -9,6 +9,7 @@ from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, + LlamaConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -52,7 +53,7 @@ class FlashLlama(FlashCausalLM): trust_remote_code=trust_remote_code, ) - config = AutoConfig.from_pretrained( + config = LlamaConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) @@ -69,7 +70,7 @@ class FlashLlama(FlashCausalLM): 
model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, + num_kv_heads=model.model.num_key_value_heads, head_size=model.model.head_size, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index f8460fc2..3827197f 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -58,8 +58,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError - def warmup(self, batch: B, max_total_tokens: int): + def warmup(self, batch: B) -> Optional[int]: self.generate_token(batch) + return None def decode_token( self, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 7bc62ce6..e0efbcf5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): filtered_batch = batch.filter(request.request_ids) self.cache.set(filtered_batch) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - self.model.warmup(batch, request.max_total_tokens) + max_supported_total_tokens = self.model.warmup(batch) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return generate_pb2.WarmupResponse() + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( @@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) > 1: batch = 
self.model.batch_type.concatenate(batches) - if torch.cuda.is_available(): - torch.cuda.empty_cache() else: batch = batches[0] diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py new file mode 100644 index 00000000..c472d1fc --- /dev/null +++ b/server/text_generation_server/utils/flash_attn.py @@ -0,0 +1,124 @@ +import os +import torch + +from loguru import logger + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") + +if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +is_sm8x = major == 8 and minor >= 0 +is_sm90 = major == 9 and minor == 0 + +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2 = False +try: + try: + import flash_attn_2_cuda + except ImportError: + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2 = True +except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, +): + if HAS_FLASH_ATTN_V2: + 
return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + + if HAS_FLASH_ATTN: + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) + + raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index 5a4ed8da..d182456f 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -13,6 +13,9 @@ import transformers from huggingface_hub import HfApi import numpy as np import torch +from accelerate import init_empty_weights +from text_generation_server.utils import initialize_torch_distributed, Weights +from text_generation_server.utils.hub import weight_files from text_generation_server.utils.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional @@ -38,7 +41,6 @@ class Quantizer(nn.Module): maxshrink=0.8, trits=False, ): - self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -600,6 +602,8 @@ def sequential( nsamples, bits, 
groupsize, + *, + hooks, percdamp=0.01, sym: bool = False, act_order: bool = False, @@ -637,7 +641,7 @@ def sequential( layers[0] = Catcher(layers[0]) for batch in dataloader: try: - model(batch[0]) + model(batch[0].cuda()) except ValueError: pass layers[0] = layers[0].module @@ -646,6 +650,8 @@ def sequential( # model.model.embed_tokens = model.model.embed_tokens.cpu() # model.model.norm = model.model.norm.cpu() torch.cuda.empty_cache() + for hook in hooks: + hook.remove() outs = torch.zeros_like(inps) @@ -662,10 +668,8 @@ def sequential( print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") print("+==================+==============+============+===========+=======+") - from accelerate.hooks import remove_hook_from_submodules - - layer = layers[i].to(dev) - remove_hook_from_submodules(layer) + layer = layers[i] + layer.load() full = find_layers(layer) sequential = [list(full.keys())] @@ -677,6 +681,7 @@ def sequential( gptq[name].quantizer.configure( bits, perchannel=True, sym=sym, mse=False ) + pass def add_batch(name): def tmp(_, inp, out): @@ -688,7 +693,6 @@ def sequential( for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] for h in handles: h.remove() @@ -714,7 +718,7 @@ def sequential( for j in range(nsamples): outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] - layers[i] = layer.cpu() + layer.unload() del layer del gptq torch.cuda.empty_cache() @@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize): return model +def setdeepattr(module, full_name, tensor): + current = module + tokens = full_name.split(".") + for token in tokens[:-1]: + current = getattr(current, token) + setattr(current, tokens[-1], tensor) + + +def getdeepattr(module, full_name): + current = module + tokens = full_name.split(".") + for token in tokens: + current = getattr(current, token) + return current + + +def load_weights_pre_hook(module_name, 
weights, recursive=False): + def inner(module, args): + print(f"Pre hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + + for local_param in local_params: + current_tensor = getdeepattr(module, local_param) + if current_tensor.device == torch.device("meta"): + # print(f"Loading {local_param}") + if module_name: + tensor_name = f"{module_name}.{local_param}" + else: + tensor_name = local_param + tensor = weights.get_tensor(tensor_name) + setdeepattr(module, local_param, nn.Parameter(tensor)) + else: + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), + ) + + return inner + + +def load_weights_post_hook(module_name, weights, recursive=False): + def inner(module, args, output): + print(f"Post hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for local_param in local_params: + # print(f"Unloading {local_param}") + current_tensor = getdeepattr(module, local_param) + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cpu"))), + ) + return output + + return inner + + def quantize( model_id: str, bits: int, groupsize: int, output_dir: str, + revision: str, trust_remote_code: bool, upload_to_model_id: Optional[str], percdamp: float, act_order: bool, ): print("loading model") - model = AutoModelForCausalLM.from_pretrained( + config = AutoConfig.from_pretrained( model_id, - torch_dtype=torch.float16, - device_map="balanced_low_0", trust_remote_code=trust_remote_code, ) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, 
torch_dtype=torch.float16) + model = model.eval() + print("LOADED model") + files = weight_files(model_id, revision, extension=".safetensors") + process_group, _, _ = initialize_torch_distributed() + weights = Weights( + files, + device=torch.device("cuda:0"), + dtype=torch.float16, + process_group=process_group, + aliases={"embed_tokens.weight": ["lm_head.weight"]}, + ) + hooks = [] + for name, module in model.named_modules(): + + def load(module, name): + def _load(): + load_weights_pre_hook(name, weights, recursive=True)(module, None) + + return _load + + def unload(module, name): + def _unload(): + load_weights_post_hook(name, weights, recursive=True)( + module, None, None + ) + + return _unload + + module.load = load(module, name) + module.unload = unload(module, name) + hooks.append( + module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) + ) + hooks.append( + module.register_forward_hook(load_weights_post_hook(name, weights)) + ) model.seqlen = 2048 dataset = "wikitext2" @@ -806,6 +922,7 @@ def quantize( groupsize, percdamp=percdamp, act_order=act_order, + hooks=hooks, ) print(time.time() - tick) @@ -858,7 +975,6 @@ def quantize( logger.info("Saved tokenizer") if upload_to_model_id: - api = HfApi() api.upload_folder(