diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ab2d3313..cb777010 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs for all power of twos until 64
                 for i in range(6):
-                    self.cuda_graph_warmup(2**i, max_s, max_bt)
+                    bs = 2**i
+                    if self.speculate is None or self.speculate + 1 <= bs:
+                        self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
 
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
 
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
 
-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index dee272a0..8d7e2a2b 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
         if self.model.max_past is not None:
             max_s = min(self.model.max_past, max_s)
 
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
 
-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
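
For context, here is a minimal, self-contained sketch of the batch-size bucketing these hunks touch. It is illustrative only: `warmup_batch_sizes`, `lookup_graph`, and the toy `graphs` dict are hypothetical names, not TGI's API. The guard `speculate + 1 <= bs` mirrors the patched warmup condition; the likely motivation is that with speculative decoding each sequence contributes `speculate + 1` rows to the `input_ids` tensor actually fed to the model, so graphs captured for smaller padded batch sizes could never be hit, and the lookup must count rows of the local `input_ids` rather than `batch.input_ids`.

```python
import math

def warmup_batch_sizes(speculate, max_pow=6):
    """Mirror of the patched warmup loop: skip padded batch sizes too
    small to hold 1 verified + `speculate` speculative ids per sequence."""
    sizes = []
    for i in range(max_pow):
        bs = 2**i
        if speculate is None or speculate + 1 <= bs:
            sizes.append(bs)
    return sizes

def lookup_graph(graphs, num_rows):
    """Round the actual row count of the model's input tensor up to the
    next power of two and look for a graph captured at that size."""
    bs_next_power_of_two = 2 ** math.ceil(math.log2(num_rows))
    return graphs.get(bs_next_power_of_two, None)

# Toy stand-in for self.cuda_graphs, keyed by padded batch size.
graphs = {bs: f"graph_bs{bs}" for bs in warmup_batch_sizes(speculate=2)}

print(warmup_batch_sizes(speculate=2))  # [4, 8, 16, 32]: bs 1 and 2 skipped
print(lookup_graph(graphs, 5))          # 'graph_bs8': 5 rows pad up to 8
print(lookup_graph(graphs, 2))          # None: falls back to eager forward
```

On this reading, the pre-patch bug was a mismatch between the two halves: warmup captured graphs for sizes the decode path could never request under speculation, while the lookup keyed off `batch.input_ids.shape[0]`, which can differ from the row count of the tensor the graph was captured with.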