fix speculate

This commit is contained in:
OlivierDehaene 2024-01-10 17:40:37 +01:00 committed by Nicolas Patry
parent ca20c304b3
commit 33e94379c8
2 changed files with 7 additions and 5 deletions

View File

@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs for all power of twos until 64
for i in range(6):
self.cuda_graph_warmup(2**i, max_s, max_bt)
bs = 2**i
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
bs = batch.input_ids.shape[0]
bs = input_ids.shape[0]
# Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None:
if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,

View File

@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
if self.model.max_past is not None:
max_s = min(self.model.max_past, max_s)
bs = batch.input_ids.shape[0]
bs = input_ids.shape[0]
# Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None:
if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,