Fix quantization defaults without cuda graphs on exl2 (linked to new issues with it).
Nicolas Patry 2024-08-15 09:12:21 +02:00
parent a041603462
commit 13350a330f
3 changed files with 82 additions and 63 deletions
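
The launcher-side change is that the quantization method can now be inferred from the model's config.json (its quantization_config.quant_method field) instead of relying only on the --quantize CLI flag, and an exl2 model detected this way gets cuda graphs disabled by default. Below is a minimal sketch of that detection, assuming a config.json shaped like typical exl2 quants on the Hub; the sample JSON, the trimmed-down enum, and the serde_json call are illustrative, not part of this commit.

use serde::Deserialize;

// Mirrors the fields this commit adds to the launcher's RawConfig.
#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    quantization_config: Option<QuantizationConfig>,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

// `rename_all = "kebab-case"` lets serde map strings like "exl2" or "awq"
// onto the same variants the CLI flag already exposes.
#[derive(Clone, Copy, Debug, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    Awq,
    Exl2,
    Gptq,
    Marlin,
    Bitsandbytes,
}

fn main() {
    // Hypothetical excerpt of the config.json of an exl2-quantized checkpoint.
    let raw: RawConfig = serde_json::from_str(
        r#"{ "max_position_embeddings": 8192,
             "quantization_config": { "quant_method": "exl2" } }"#,
    )
    .unwrap();

    // Same extraction the new From<RawConfig> impl performs.
    let quantize = raw.quantization_config.and_then(|q| q.quant_method);
    assert_eq!(quantize, Some(Quantization::Exl2));
}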

View File

@@ -250,6 +250,8 @@ RUN cd server && \
     pip install nvidia-nccl-cu12==2.22.3
 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1
 # Deps before the binaries

View File

@@ -6,9 +6,6 @@ def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
         revision="2.5bpw",
-        # TODO
-        # Exl2 is currently broken with cuda graphs.
-        cuda_graphs=[0],
         # Set max input length to avoid OOM due to extremely large
         # scratch buffer.
         max_input_length=1024,

View File

@@ -30,11 +30,18 @@ struct RawConfig {
     n_positions: Option<usize>,
     model_type: Option<String>,
     max_seq_len: Option<usize>,
+    quantization_config: Option<QuantizationConfig>,
+}
+
+#[derive(Deserialize)]
+struct QuantizationConfig {
+    quant_method: Option<Quantization>,
 }
 
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
+    quantize: Option<Quantization>,
 }
 
 impl From<RawConfig> for Config {
@@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
             .max_position_embeddings
             .or(other.max_seq_len)
             .or(other.n_positions);
+        let quantize = other.quantization_config.and_then(|q| q.quant_method);
         Config {
             max_position_embeddings,
+            quantize,
         }
     }
 }
 
-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
+#[serde(rename_all = "kebab-case")]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
     /// <https://hf.co/models?search=awq>.
@@ -72,10 +82,10 @@ enum Quantization {
     Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
-    #[deprecated(
-        since = "1.1.0",
-        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
-    )]
+    // #[deprecated(
+    //     since = "1.1.0",
+    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    // )]
     Bitsandbytes,
     /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
     /// but it is known that the model will be much slower to run than the native f16.
@@ -1085,6 +1095,7 @@ fn spawn_shards(
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
     max_input_tokens: usize,
+    quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1106,7 +1117,6 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let otlp_service_name = args.otlp_service_name.clone();
-        let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
@@ -1429,7 +1439,8 @@ fn main() -> Result<(), LauncherError> {
     tracing::info!("{:#?}", args);
 
-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
+    let get_max_positions_quantize =
+        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
         let model_id = args.model_id.clone();
         let mut path = std::path::Path::new(&args.model_id).to_path_buf();
         let filename = if !path.exists() {
@@ -1464,6 +1475,7 @@ fn main() -> Result<(), LauncherError> {
                 std::env::set_var("ATTENTION", "flashdecoding");
             }
             let config: Config = config.into();
+            let quantize = config.quantize;
 
             // Quantization usually means you're even more RAM constrained.
             let max_default = 4096;
@@ -1477,9 +1489,9 @@ fn main() -> Result<(), LauncherError> {
                 {
                     tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                 }
-                Ok(max_default)
+                Ok((max_default, quantize))
             } else {
-                Ok(max_position_embeddings)
+                Ok((max_position_embeddings, quantize))
             }
         } else {
             Err(Box::new(LauncherError::ArgumentValidation(
@@ -1487,7 +1499,8 @@ fn main() -> Result<(), LauncherError> {
             )))
         }
     };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
+        get_max_positions_quantize().unwrap_or((4096, None));
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
@@ -1544,7 +1557,9 @@ fn main() -> Result<(), LauncherError> {
         )));
     }
 
-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+    let quantize = args.quantize.or(quantize);
+    tracing::info!("Quantize found {quantize:?}");
+    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
@@ -1558,6 +1573,10 @@ fn main() -> Result<(), LauncherError> {
             tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
             vec![]
         }
+        (None, Some(Quantization::Exl2)) => {
+            tracing::info!("Exl2 doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
         _ => {
             let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
             tracing::info!("Using default cuda graphs {cuda_graphs:?}");
@@ -1672,6 +1691,7 @@ fn main() -> Result<(), LauncherError> {
         cuda_graphs,
         max_total_tokens,
         max_input_tokens,
+        quantize,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
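
The net effect: the CLI flag still takes precedence when it is set (args.quantize.or(quantize)), but an exl2 checkpoint is now recognized from its config alone, so cuda graphs are turned off without any extra flag, which is why the integration-test fixture above could drop cuda_graphs=[0]. A hypothetical invocation against the model used in that test; the expected output lines are paraphrased from the tracing::info! calls added above, not captured logs:

text-generation-launcher \
    --model-id turboderp/Llama-3-8B-Instruct-exl2 \
    --revision 2.5bpw
# Expected in the logs, with no --quantize or --cuda-graphs passed:
#   Quantize found Some(Exl2)
#   Exl2 doesn't work with cuda graphs, deactivating them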