Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix speculate

commit 33e94379c8 (parent ca20c304b3)

The warmup loop now skips cuda graph batch sizes that speculative decoding can never produce, and both forward paths size their cuda graph lookup from the locally unpacked input_ids / cu_seqlen_prefill instead of the raw batch tensors.
@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs for all power of twos until 64
                 for i in range(6):
-                    self.cuda_graph_warmup(2**i, max_s, max_bt)
+                    bs = 2**i
+                    if self.speculate is None or self.speculate + 1 <= bs:
+                        self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
 
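For readers following the warmup change: if I read the guard correctly, speculative decoding feeds speculate + 1 ids per sequence into every decode step, so the flattened batch can never be smaller than speculate + 1, and capturing graphs below that size is wasted work. A minimal standalone sketch of the new selection logic; warmup_batch_sizes is a hypothetical helper for illustration, not part of the repo:

    def warmup_batch_sizes(speculate, max_power=6):
        """Power-of-two batch sizes worth capturing as cuda graphs.

        With speculation, every decode step runs on speculate + 1 ids per
        sequence, so graphs smaller than speculate + 1 are never looked up.
        """
        for i in range(max_power):
            bs = 2**i
            if speculate is None or speculate + 1 <= bs:
                yield bs


    print(list(warmup_batch_sizes(None)))  # [1, 2, 4, 8, 16, 32]
    print(list(warmup_batch_sizes(2)))     # [4, 8, 16, 32]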
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
 
-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
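The two forward hunks share the same lookup pattern: round the runtime batch size up to the next power of two and probe the graph cache, falling back to an eager forward on prefill or on a cache miss. A rough standalone sketch, assuming cuda_graphs is a plain dict keyed by captured batch size as the diff suggests:

    import math

    def find_cuda_graph(cuda_graphs, bs):
        """Return the graph captured for the next power-of-two batch size, or None.

        Graphs exist only for power-of-two sizes, so e.g. bs=5 rounds up to 8
        and reuses the bs=8 graph.
        """
        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
        return cuda_graphs.get(bs_next_power_of_two, None)


    graphs = {1: "g1", 2: "g2", 4: "g4", 8: "g8"}
    print(find_cuda_graph(graphs, 5))   # g8
    print(find_cuda_graph(graphs, 16))  # None -> caller runs an eager forward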
@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
         if self.model.max_past is not None:
             max_s = min(self.model.max_past, max_s)
 
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
 
-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
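As for why batch.input_ids and batch.cu_seqlen_prefill were the wrong tensors to consult: the context lines (max_s = batch.max_seqlen, lm_head_indices = batch.prefill_head_indices) suggest both forward methods unpack the batch into locals earlier, and under speculation those locals are expanded so each sequence also carries its speculative ids. The batch-level tensor keeps the per-sequence shape, so sizing the graph lookup from it under-counts. A toy illustration of the mismatch, with hypothetical shapes (two sequences, two speculative tokens each):

    import torch

    num_seqs, speculate = 2, 2

    # what the batch object stores: one current token id per sequence
    batch_input_ids = torch.zeros(num_seqs, dtype=torch.long)

    # what the decode kernel actually runs on under speculation:
    # the current id plus `speculate` speculative ids per sequence
    input_ids = torch.zeros(num_seqs * (speculate + 1), dtype=torch.long)

    # the old code sized the graph from the left value and picked a bs=2 graph,
    # while the kernel saw 6 rows; the fix reads the expanded tensor instead
    print(batch_input_ids.shape[0], input_ids.shape[0])  # 2 6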