From 33e94379c8534716f57b4e389d40c45db9defd62 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 10 Jan 2024 17:40:37 +0100
Subject: [PATCH] fix speculate

---
 server/text_generation_server/models/flash_causal_lm.py | 8 +++++---
 server/text_generation_server/models/flash_mistral.py   | 4 ++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ab2d3313..cb777010 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -782,7 +782,9 @@ class FlashCausalLM(Model):
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs for all power of twos until 64
                 for i in range(6):
-                    self.cuda_graph_warmup(2**i, max_s, max_bt)
+                    bs = 2**i
+                    if self.speculate is None or self.speculate + 1 <= bs:
+                        self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
@@ -840,13 +842,13 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))

         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)

-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index dee272a0..8d7e2a2b 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -446,13 +446,13 @@ class BaseFlashMistral(FlashCausalLM):
         if self.model.max_past is not None:
             max_s = min(self.model.max_past, max_s)
-        bs = batch.input_ids.shape[0]
+        bs = input_ids.shape[0]
         # Ceil next power of two for batch size
         bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))

         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)

-        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,