Remove unecessary cuda graph.

This commit is contained in:
Nicolas Patry 2024-03-21 20:09:28 +01:00
parent de6cb15fa5
commit 2e754ffd2e

View File

@ -802,7 +802,7 @@ class FlashCausalLM(Model):
try:
logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs
for bs in [1, 2, 4] + [8 * i for i in range(8)]:
for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except Exception: