now working

This commit is contained in:
fxmarty 2024-04-19 12:08:39 +02:00
parent 24d43c487e
commit 804068c207
4 changed files with 10 additions and 2 deletions

View File

@@ -1373,6 +1373,7 @@ fn main() -> Result<(), LauncherError> {
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
(Some(cuda_graphs), None) => cuda_graphs.clone(),
#[allow(deprecated)]
(
None,
@@ -1385,7 +1386,7 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
(None, _) => {
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs

View File

@@ -816,6 +816,8 @@ class FlashCausalLM(Model):
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return int(num_blocks * BLOCK_SIZE)

View File

@@ -4,11 +4,14 @@ import os
MEM_POOL = torch.cuda.graph_pool_handle()
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
if cuda_graphs is not None and cuda_graphs != "0":
try:
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
except Exception as e:
raise RuntimeError(
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
)
else:
cuda_graphs = None
CUDA_GRAPHS = cuda_graphs

View File

@@ -474,6 +474,8 @@ class Mamba(Model):
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None