Fix disabling.

This commit is contained in:
Nicolas Patry 2024-04-04 13:01:27 +00:00
parent 6951962ffd
commit d67633a0c8

View File

@ -3,11 +3,12 @@ import os
MEM_POOL = torch.cuda.graph_pool_handle() MEM_POOL = torch.cuda.graph_pool_handle()
# This is overridden by the cli # This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128") cuda_graphs = os.getenv("CUDA_GRAPHS")
try: if cuda_graphs is not None:
cuda_graphs = [int(item) for item in cuda_graphs.split(",")] try:
except Exception as e: cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
raise RuntimeError( except Exception as e:
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" raise RuntimeError(
) f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
)
CUDA_GRAPHS = cuda_graphs CUDA_GRAPHS = cuda_graphs