mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-02 23:42:06 +00:00)

commit b503b3de60 (parent 3016e1595f): tunableop in warmup
@@ -807,6 +807,7 @@ class FlashCausalLM(Model):
            self.device,
        )

        logger.info(f"CUDA_GRAPHS: {CUDA_GRAPHS}")
        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -817,8 +818,39 @@ class FlashCausalLM(Model):
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")

        if IS_ROCM_SYSTEM and TUNABLEOP:
            # Tune the decode GEMMs once per batch size; skip the empty batch.
            total_seqlens = list(range(1, 16))
            for seqlen in total_seqlens:
                self.tunableop_warmup(seqlen, max_s, max_bt)

        return int(num_blocks * BLOCK_SIZE)

    def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
        # Dummy decode batch: one token per sequence, so batch size == seqlen.
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # TODO: is this correct?
        # Pretend every sequence is at the maximum length, so attention spans
        # the whole allocated KV cache.
        input_lengths = (
            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
        )
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(seqlen)
            .reshape((seqlen, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,  # decode-only, no prefill
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )

    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
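For context: PyTorch's TunableOp benchmarks each GEMM shape the first time it is hit and caches the winning kernel, so the loop above runs one dummy decode per batch size to pay that cost at startup rather than on live requests. Below is a minimal, self-contained sketch of the same pattern outside TGI; torch.cuda.tunable ships with recent PyTorch (older builds expose the same switches via the PYTORCH_TUNABLEOP_* environment variables), and the matrix shapes and filename here are arbitrary stand-ins, not TGI's values.

    # Hypothetical standalone TunableOp warmup; shapes are illustrative only.
    import torch
    import torch.cuda.tunable as tunable

    tunable.enable(True)         # consult/record the tuned-solution cache
    tunable.tuning_enable(True)  # benchmark candidate kernels on a cache miss

    # Hit every GEMM "M" dimension (i.e. decode batch size) we expect to serve.
    weight = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    for batch_size in range(1, 16):
        x = torch.randn(batch_size, 4096, device="cuda", dtype=torch.float16)
        _ = x @ weight

    tunable.write_file("tunableop_results.csv")  # persist results for later runs

Because results are keyed by exact shape, any batch size not warmed here would trigger tuning, and hence a latency spike, the first time it appears in production.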
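The "Decode cuda graph warmup failed" handler in the first hunk belongs to the existing CUDA-graph warmup, which records one decode step per captured batch size and replays it later. A toy sketch of that capture/replay pattern using only stock torch.cuda APIs, with a single linear layer standing in for the model (nothing below is TGI code):

    # Toy CUDA-graph capture of a "decode step"; the model is a placeholder.
    import torch

    model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
    static_input = torch.zeros(4, 4096, device="cuda", dtype=torch.float16)

    # Warm up on a side stream before capture, as the CUDA graphs docs advise.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replays reuse the captured kernels; only the static buffers change.
    static_input.copy_(torch.randn_like(static_input))
    graph.replay()
    print(static_output.float().norm())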