This commit is contained in:
OlivierDehaene 2024-01-10 17:17:48 +01:00 committed by Nicolas Patry
parent 8260dc00d8
commit ca20c304b3

View File

@ -779,6 +779,7 @@ class FlashCausalLM(Model):
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
try:
logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs for all power of twos until 64
for i in range(6):
self.cuda_graph_warmup(2**i, max_s, max_bt)