diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2c440083..e474f9d6 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -807,6 +807,7 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        logger.info(f"CUDA_GRAPHS: {CUDA_GRAPHS}")
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -817,8 +818,39 @@ class FlashCausalLM(Model):
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
 
+        if IS_ROCM_SYSTEM and TUNABLEOP:
+            total_seqlens = list(range(16))
+            for seqlen in total_seqlens:
+                self.tunableop_warmup(seqlen, max_s, max_bt)
+
         return int(num_blocks * BLOCK_SIZE)
 
+    def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
+
+        # Pretend every sequence in the dummy decode batch already holds max_s tokens.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(seqlen)
+            .reshape((seqlen, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
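
For context on the new branch: TunableOp is PyTorch's ROCm GEMM auto-tuner, and the warmup loop above runs one decode forward per batch size so each GEMM shape gets tuned once before serving. The sketch below shows how TunableOp is typically switched on through PyTorch's standard environment variables; the helper name warmup_gemm_shapes, the hidden size, and the CSV filename are illustrative only, and how text-generation-inference derives its TUNABLEOP flag from these variables is not shown in this diff.

import os

# TunableOp is configured through environment variables that must be set before
# the first tunable GEMM runs (simplest: before importing torch).
os.environ.setdefault("PYTORCH_TUNABLEOP_ENABLED", "1")   # turn TunableOp on
os.environ.setdefault("PYTORCH_TUNABLEOP_TUNING", "1")    # tune new shapes, not just replay
os.environ.setdefault("PYTORCH_TUNABLEOP_FILENAME", "tunableop_results.csv")  # tuned kernels go here

import torch


def warmup_gemm_shapes(hidden_size: int = 4096, batch_sizes=range(1, 16)) -> None:
    """Run one fp16 matmul per decode batch size so TunableOp tunes each shape once.

    This mirrors what tunableop_warmup achieves through full model forwards:
    each distinct (batch_size, hidden_size) GEMM is tuned on first use and the
    selected kernels are saved to the CSV above for replay on later runs.
    """
    weight = torch.randn(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
    for bs in batch_sizes:
        x = torch.randn(bs, hidden_size, device="cuda", dtype=torch.float16)
        _ = x @ weight  # one tunable GEMM per decode batch size
    torch.cuda.synchronize()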