Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 12:54:52 +00:00
Fix quantization defaults without cuda graphs on exl2 (cuda graphs are linked to new issues with exl2).
This commit is contained in:
parent a041603462
commit 13350a330f
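For context, the change wires a config-derived quantization default through the launcher: the model's config.json is parsed into RawConfig, `quantization_config.quant_method` becomes the fallback for `--quantize`, and an effective exl2 value now disables cuda graphs just like bitsandbytes. Below is a minimal, self-contained sketch of that flow; it is not the launcher's actual code, and the JSON snippet, enum variants, and crate setup (serde with derive, serde_json) are assumptions for illustration.

// Standalone sketch (not the launcher's code): how a quantization method declared
// in a model's config.json becomes the default when --quantize is not passed, and
// how an effective exl2 quantization then disables cuda graphs.
// Assumes the serde (derive feature) and serde_json crates; the JSON snippet and
// enum variants are illustrative.
use serde::Deserialize;

#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    Exl2,
    Bitsandbytes,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Deserialize)]
struct RawConfig {
    quantization_config: Option<QuantizationConfig>,
}

fn main() {
    // Roughly what an exl2 checkpoint's config.json may contain (illustrative).
    let content = r#"{ "quantization_config": { "quant_method": "exl2" } }"#;
    let config: RawConfig = serde_json::from_str(content).unwrap();
    let config_quantize = config.quantization_config.and_then(|q| q.quant_method);

    // An explicit CLI flag still wins; the config.json value is only a fallback.
    let cli_quantize: Option<Quantization> = None; // pretend --quantize was not passed
    let quantize = cli_quantize.or(config_quantize);

    // Exl2, like bitsandbytes, gets an empty cuda graph list by default.
    let cuda_graphs: Vec<usize> = match quantize {
        Some(Quantization::Exl2) | Some(Quantization::Bitsandbytes) => vec![],
        _ => vec![1, 2, 4, 8, 16, 32],
    };
    println!("quantize={quantize:?} cuda_graphs={cuda_graphs:?}");
}

In the actual diff the same precedence is expressed as `args.quantize.or(quantize)`, and the cuda graph defaults are picked in the `match (&args.cuda_graphs, &quantize)` block further down.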
@@ -250,6 +250,8 @@ RUN cd server && \
     pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1

 # Deps before the binaries
@@ -6,9 +6,6 @@ def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
         revision="2.5bpw",
-        # TODO
-        # Exl2 is currently broken with cuda graphs.
-        cuda_graphs=[0],
         # Set max input length to avoid OOM due to extremely large
         # scratch buffer.
         max_input_length=1024,
@@ -30,11 +30,18 @@ struct RawConfig {
     n_positions: Option<usize>,
     model_type: Option<String>,
     max_seq_len: Option<usize>,
+    quantization_config: Option<QuantizationConfig>,
+}
+
+#[derive(Deserialize)]
+struct QuantizationConfig {
+    quant_method: Option<Quantization>,
 }

 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
+    quantize: Option<Quantization>,
 }

 impl From<RawConfig> for Config {
@@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
             .max_position_embeddings
             .or(other.max_seq_len)
             .or(other.n_positions);
+        let quantize = other.quantization_config.and_then(|q| q.quant_method);
         Config {
             max_position_embeddings,
+            quantize,
         }
     }
 }

-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
+#[serde(rename_all = "kebab-case")]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
     /// <https://hf.co/models?search=awq>.
@@ -72,10 +82,10 @@ enum Quantization {
     Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
-    #[deprecated(
-        since = "1.1.0",
-        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
-    )]
+    // #[deprecated(
+    //     since = "1.1.0",
+    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    // )]
     Bitsandbytes,
     /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
     /// but it is known that the model will be much slower to run than the native f16.
@@ -1085,6 +1095,7 @@ fn spawn_shards(
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
     max_input_tokens: usize,
+    quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1106,7 +1117,6 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let otlp_service_name = args.otlp_service_name.clone();
-        let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
@@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {

     tracing::info!("{:#?}", args);

-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
-        let model_id = args.model_id.clone();
-        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-        let filename = if !path.exists() {
-            // Assume it's a hub id
-
-            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                // env variable has precedence over on file token.
-                ApiBuilder::new().with_token(Some(token)).build()?
-            } else {
-                Api::new()?
-            };
-            let repo = if let Some(ref revision) = args.revision {
-                api.repo(Repo::with_revision(
-                    model_id,
-                    RepoType::Model,
-                    revision.to_string(),
-                ))
-            } else {
-                api.model(model_id)
-            };
-            repo.get("config.json")?
-        } else {
-            path.push("config.json");
-            path
-        };
-
-        let content = std::fs::read_to_string(filename)?;
-        let config: RawConfig = serde_json::from_str(&content)?;
-
-        if config.model_type == Some("gemma2".to_string()) {
-            tracing::info!("Forcing flash decoding because of softcap usage");
-            std::env::set_var("ATTENTION", "flashdecoding");
-        }
-        let config: Config = config.into();
-
-        // Quantization usually means you're even more RAM constrained.
-        let max_default = 4096;
-
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
-                Ok(max_default)
-            } else {
-                Ok(max_position_embeddings)
-            }
-        } else {
-            Err(Box::new(LauncherError::ArgumentValidation(
-                "no max defined".to_string(),
-            )))
-        }
-    };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+    let get_max_positions_quantize =
+        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
+            let model_id = args.model_id.clone();
+            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+            let filename = if !path.exists() {
+                // Assume it's a hub id
+
+                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                    // env variable has precedence over on file token.
+                    ApiBuilder::new().with_token(Some(token)).build()?
+                } else {
+                    Api::new()?
+                };
+                let repo = if let Some(ref revision) = args.revision {
+                    api.repo(Repo::with_revision(
+                        model_id,
+                        RepoType::Model,
+                        revision.to_string(),
+                    ))
+                } else {
+                    api.model(model_id)
+                };
+                repo.get("config.json")?
+            } else {
+                path.push("config.json");
+                path
+            };
+
+            let content = std::fs::read_to_string(filename)?;
+            let config: RawConfig = serde_json::from_str(&content)?;
+
+            if config.model_type == Some("gemma2".to_string()) {
+                tracing::info!("Forcing flash decoding because of softcap usage");
+                std::env::set_var("ATTENTION", "flashdecoding");
+            }
+            let config: Config = config.into();
+            let quantize = config.quantize;
+
+            // Quantization usually means you're even more RAM constrained.
+            let max_default = 4096;
+
+            if let Some(max_position_embeddings) = config.max_position_embeddings {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    if args.max_input_tokens.is_none()
+                        && args.max_total_tokens.is_none()
+                        && args.max_batch_prefill_tokens.is_none()
+                    {
+                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    }
+                    Ok((max_default, quantize))
+                } else {
+                    Ok((max_position_embeddings, quantize))
+                }
+            } else {
+                Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )))
+            }
+        };
+    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
+        get_max_positions_quantize().unwrap_or((4096, None));

     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
@@ -1544,7 +1557,9 @@ fn main() -> Result<(), LauncherError> {
         )));
     }

-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+    let quantize = args.quantize.or(quantize);
+    tracing::info!("Quantize found {quantize:?}");
+    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
@@ -1558,6 +1573,10 @@ fn main() -> Result<(), LauncherError> {
             tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
             vec![]
         }
+        (None, Some(Quantization::Exl2)) => {
+            tracing::info!("Exl2 doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
         _ => {
             let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
             tracing::info!("Using default cuda graphs {cuda_graphs:?}");
@@ -1672,6 +1691,7 @@ fn main() -> Result<(), LauncherError> {
         cuda_graphs,
         max_total_tokens,
         max_input_tokens,
+        quantize,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,