tunableop in warmup

This commit is contained in:
fxmarty 2024-04-19 09:09:16 +00:00
parent 3016e1595f
commit b503b3de60

View File

@ -807,6 +807,7 @@ class FlashCausalLM(Model):
self.device, self.device,
) )
logger.info("CUDA_GRAPHS", CUDA_GRAPHS)
if CUDA_GRAPHS: if CUDA_GRAPHS:
try: try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@ -817,8 +818,39 @@ class FlashCausalLM(Model):
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")
if IS_ROCM_SYSTEM and TUNABLEOP:
total_seqlens = list(range(16))
for seqlen in total_seqlens:
self.tunableop_warmup(seqlen, max_s, max_bt)
return int(num_blocks * BLOCK_SIZE) return int(num_blocks * BLOCK_SIZE)
def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
# TODO: is this correct?
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=None,
)
def forward( def forward(
self, batch: FlashCausalLMBatch self, batch: FlashCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: