From 13350a330f21db0ca919b9c64297dba32bd108ce Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 15 Aug 2024 09:12:21 +0200
Subject: [PATCH] Fix quantization defaults without cuda graphs on exl2
 (linked to new issues with it).

---
 Dockerfile                                    |   2 +
 .../models/test_flash_llama_exl2.py           |   3 -
 launcher/src/main.rs                          | 140 ++++++++++--------
 3 files changed, 82 insertions(+), 63 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 60dfde75..b2d274d7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -250,6 +250,8 @@ RUN cd server && \
     pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1

 # Deps before the binaries

diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py
index b0bddf17..18319f60 100644
--- a/integration-tests/models/test_flash_llama_exl2.py
+++ b/integration-tests/models/test_flash_llama_exl2.py
@@ -6,9 +6,6 @@ def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
         revision="2.5bpw",
-        # TODO
-        # Exl2 is currently broken with cuda graphs.
-        cuda_graphs=[0],
         # Set max input length to avoid OOM due to extremely large
         # scratch buffer.
         max_input_length=1024,

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a64b1d71..58abb306 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -30,11 +30,18 @@ struct RawConfig {
     n_positions: Option<usize>,
     model_type: Option<String>,
     max_seq_len: Option<usize>,
+    quantization_config: Option<QuantizationConfig>,
+}
+
+#[derive(Deserialize)]
+struct QuantizationConfig {
+    quant_method: Option<Quantization>,
 }

 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
+    quantize: Option<Quantization>,
 }

 impl From<RawConfig> for Config {
@@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
             .max_position_embeddings
             .or(other.max_seq_len)
             .or(other.n_positions);
+        let quantize = other.quantization_config.and_then(|q| q.quant_method);
         Config {
             max_position_embeddings,
+            quantize,
         }
     }
 }

-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
+#[serde(rename_all = "kebab-case")]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
     /// .
@@ -72,10 +82,10 @@ enum Quantization {
     Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
-    #[deprecated(
-        since = "1.1.0",
-        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
-    )]
+    // #[deprecated(
+    //     since = "1.1.0",
+    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    // )]
     Bitsandbytes,
     /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
     /// but it is known that the model will be much slower to run than the native f16.
@@ -1085,6 +1095,7 @@ fn spawn_shards(
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
     max_input_tokens: usize,
+    quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1106,7 +1117,6 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let otlp_service_name = args.otlp_service_name.clone();
-        let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
@@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {

     tracing::info!("{:#?}", args);

-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
-        let model_id = args.model_id.clone();
-        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-        let filename = if !path.exists() {
-            // Assume it's a hub id
+    let get_max_positions_quantize =
+        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
+            let model_id = args.model_id.clone();
+            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+            let filename = if !path.exists() {
+                // Assume it's a hub id

-            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                // env variable has precedence over on file token.
-                ApiBuilder::new().with_token(Some(token)).build()?
+                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                    // env variable has precedence over on file token.
+                    ApiBuilder::new().with_token(Some(token)).build()?
+                } else {
+                    Api::new()?
+                };
+                let repo = if let Some(ref revision) = args.revision {
+                    api.repo(Repo::with_revision(
+                        model_id,
+                        RepoType::Model,
+                        revision.to_string(),
+                    ))
+                } else {
+                    api.model(model_id)
+                };
+                repo.get("config.json")?
             } else {
-                Api::new()?
+                path.push("config.json");
+                path
             };
-            let repo = if let Some(ref revision) = args.revision {
-                api.repo(Repo::with_revision(
-                    model_id,
-                    RepoType::Model,
-                    revision.to_string(),
-                ))
-            } else {
-                api.model(model_id)
-            };
-            repo.get("config.json")?
-        } else {
-            path.push("config.json");
-            path
-        };
-        let content = std::fs::read_to_string(filename)?;
-        let config: RawConfig = serde_json::from_str(&content)?;
+            let content = std::fs::read_to_string(filename)?;
+            let config: RawConfig = serde_json::from_str(&content)?;

-        if config.model_type == Some("gemma2".to_string()) {
-            tracing::info!("Forcing flash decoding because of softcap usage");
-            std::env::set_var("ATTENTION", "flashdecoding");
-        }
-        let config: Config = config.into();
-
-        // Quantization usually means you're even more RAM constrained.
-        let max_default = 4096;
-
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
-                Ok(max_default)
-            } else {
-                Ok(max_position_embeddings)
+            if config.model_type == Some("gemma2".to_string()) {
+                tracing::info!("Forcing flash decoding because of softcap usage");
+                std::env::set_var("ATTENTION", "flashdecoding");
             }
-        } else {
-            Err(Box::new(LauncherError::ArgumentValidation(
-                "no max defined".to_string(),
-            )))
-        }
-    };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+            let config: Config = config.into();
+            let quantize = config.quantize;
+
+            // Quantization usually means you're even more RAM constrained.
+            let max_default = 4096;
+
+            if let Some(max_position_embeddings) = config.max_position_embeddings {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    if args.max_input_tokens.is_none()
+                        && args.max_total_tokens.is_none()
+                        && args.max_batch_prefill_tokens.is_none()
+                    {
+                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    }
+                    Ok((max_default, quantize))
+                } else {
+                    Ok((max_position_embeddings, quantize))
+                }
+            } else {
+                Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )))
+            }
+        };
+    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
+        get_max_positions_quantize().unwrap_or((4096, None));

     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
@@ -1544,7 +1557,9 @@ fn main() -> Result<(), LauncherError> {
         )));
     }

-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+    let quantize = args.quantize.or(quantize);
+    tracing::info!("Quantize found {quantize:?}");
+    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
@@ -1558,6 +1573,10 @@ fn main() -> Result<(), LauncherError> {
             tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
             vec![]
         }
+        (None, Some(Quantization::Exl2)) => {
+            tracing::info!("Exl2 doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
         _ => {
             let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
             tracing::info!("Using default cuda graphs {cuda_graphs:?}");
@@ -1672,6 +1691,7 @@ fn main() -> Result<(), LauncherError> {
         cuda_graphs,
         max_total_tokens,
         max_input_tokens,
+        quantize,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
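
Not part of the patch itself: the launcher-side change boils down to reading `quantization_config.quant_method` from the model's config.json, letting an explicit `--quantize` flag take precedence, and deactivating cuda graphs by default when the result is exl2 (or bitsandbytes). A minimal sketch of that lookup, assuming serde and serde_json as dependencies and an illustrative config.json excerpt rather than any specific model repository:

// Sketch only (assumes serde + serde_json). It mirrors the RawConfig /
// QuantizationConfig structs added by the patch; the JSON excerpt below is
// hypothetical, not copied from a particular model repo.
use serde::Deserialize;

#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    // Subset of the launcher's variants, enough for the example.
    Awq,
    Exl2,
    Gptq,
    Marlin,
    Bitsandbytes,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Deserialize)]
struct RawConfig {
    quantization_config: Option<QuantizationConfig>,
}

fn main() -> Result<(), serde_json::Error> {
    // Hypothetical excerpt of a quantized model's config.json.
    let content = r#"{ "quantization_config": { "quant_method": "exl2" } }"#;
    let raw: RawConfig = serde_json::from_str(content)?;

    // Same precedence as the patch: an explicit --quantize flag wins,
    // otherwise fall back to whatever the config file declares.
    let cli_flag: Option<Quantization> = None;
    let quantize = cli_flag.or(raw.quantization_config.and_then(|q| q.quant_method));

    // With Some(Exl2) here and no explicit --cuda-graphs, the launcher now
    // defaults to an empty cuda graph list instead of [1, 2, 4, 8, 16, 32].
    println!("quantize detected: {quantize:?}");
    Ok(())
}

The `#[serde(rename_all = "kebab-case")]` attribute added in the patch is what lets the same `Quantization` enum serve both as the clap `ValueEnum` for `--quantize` and as the serde target for `quant_method` values such as "exl2" or "awq".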