fix speculate

This commit is contained in:
OlivierDehaene 2024-01-10 17:40:37 +01:00 committed by Nicolas Patry
parent ca20c304b3
commit 33e94379c8
2 changed files with 7 additions and 5 deletions

View File

@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
logger.info("Experimental support for Cuda Graphs is enabled") logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs for all power of twos until 64 # Warmup cuda graphs for all power of twos until 64
for i in range(6): for i in range(6):
self.cuda_graph_warmup(2**i, max_s, max_bt) bs = 2**i
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except Exception: except Exception:
logger.exception(f"Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
max_s = batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
bs = batch.input_ids.shape[0] bs = input_ids.shape[0]
# Ceil next power of two for batch size # Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs)) bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None) cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward( return self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,

View File

@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
if self.model.max_past is not None: if self.model.max_past is not None:
max_s = min(self.model.max_past, max_s) max_s = min(self.model.max_past, max_s)
bs = batch.input_ids.shape[0] bs = input_ids.shape[0]
# Ceil next power of two for batch size # Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs)) bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None) cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward( logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,