This commit is contained in:
fxmarty 2024-05-17 16:03:15 +00:00
parent 422bf1f986
commit cd3c28cfe7
2 changed files with 7 additions and 3 deletions

View File

@@ -827,7 +827,7 @@ class FlashCausalLM(Model):
self.device, self.device,
) )
if SYSTEM == "rocm": if SYSTEM == "rocm" and self.speculate is None or self.speculate == 0:
if ( if (
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
@@ -875,7 +875,11 @@ class FlashCausalLM(Model):
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
# Warmup cuda graphs # Warmup cuda graphs
for bs in CUDA_GRAPHS: for bs in CUDA_GRAPHS:
if self.speculate is None or self.speculate + 1 <= bs: if (
self.speculate is None
or self.speculate == 0
or self.speculate + 1 <= bs
):
self.cuda_graph_warmup(bs, max_s, max_bt) self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")