Fix quantization defaults without cuda graphs on exl2 (linked to new

issues with it).
This commit is contained in:
Nicolas Patry 2024-08-15 09:12:21 +02:00
parent a041603462
commit 13350a330f
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 82 additions and 63 deletions

View File

@ -250,6 +250,8 @@ RUN cd server && \
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
# Deps before the binaries

View File

@ -6,9 +6,6 @@ def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",
revision="2.5bpw",
# TODO
# Exl2 is currently broken with cuda graphs.
cuda_graphs=[0],
# Set max input length to avoid OOM due to extremely large
# scratch buffer.
max_input_length=1024,

View File

@ -30,11 +30,18 @@ struct RawConfig {
n_positions: Option<usize>,
model_type: Option<String>,
max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
}
#[derive(Deserialize)]
struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
}
impl From<RawConfig> for Config {
@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
.max_position_embeddings
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
Config {
max_position_embeddings,
quantize,
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model:
/// <https://hf.co/models?search=awq>.
@ -72,10 +82,10 @@ enum Quantization {
Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16.
#[deprecated(
since = "1.1.0",
note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
)]
// #[deprecated(
// since = "1.1.0",
// note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
// )]
Bitsandbytes,
/// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
/// but it is known that the model will be much slower to run than the native f16.
@ -1085,6 +1095,7 @@ fn spawn_shards(
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
quantize: Option<Quantization>,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@ -1106,7 +1117,6 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args);
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let get_max_positions_quantize =
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
Api::new()?
path.push("config.json");
path
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok(max_default)
} else {
Ok(max_position_embeddings)
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
}
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let config: Config = config.into();
let quantize = config.quantize;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok((max_default, quantize))
} else {
Ok((max_position_embeddings, quantize))
}
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
}
};
let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
get_max_positions_quantize().unwrap_or((4096, None));
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
@ -1544,7 +1557,9 @@ fn main() -> Result<(), LauncherError> {
)));
}
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
let quantize = args.quantize.or(quantize);
tracing::info!("Quantize found {quantize:?}");
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)]
(
@ -1558,6 +1573,10 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
(None, Some(Quantization::Exl2)) => {
tracing::info!("Exl2 doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
@ -1672,6 +1691,7 @@ fn main() -> Result<(), LauncherError> {
cuda_graphs,
max_total_tokens,
max_input_tokens,
quantize,
max_log_level,
shutdown.clone(),
&shutdown_receiver,