mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix speculate
This commit is contained in:
parent
ca20c304b3
commit
33e94379c8
@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
|
||||
logger.info("Experimental support for Cuda Graphs is enabled")
|
||||
# Warmup cuda graphs for all power of twos until 64
|
||||
for i in range(6):
|
||||
-self.cuda_graph_warmup(2**i, max_s, max_bt)
|
||||
+bs = 2**i
|
||||
+if self.speculate is None or self.speculate + 1 <= bs:
|
||||
+self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||
except Exception:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
|
||||
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
-bs = batch.input_ids.shape[0]
|
||||
+bs = input_ids.shape[0]
|
||||
# Ceil next power of two for batch size
|
||||
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
|
||||
|
||||
-if batch.cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
+if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
if self.model.max_past is not None:
|
||||
max_s = min(self.model.max_past, max_s)
|
||||
|
||||
-bs = batch.input_ids.shape[0]
|
||||
+bs = input_ids.shape[0]
|
||||
# Ceil next power of two for batch size
|
||||
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
|
||||
|
||||
-if batch.cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
+if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user