improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date: 2025-04-11 06:10:17 -07:00
parent 76cc129796
commit ba049c9d49
3 changed files with 36 additions and 24 deletions


@@ -89,7 +89,7 @@ def get_sliding_windows() -> int:
 def prepare_for_decode(
-    dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
+    dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
 ):
     # Prepare values if we need to continue decoding
     # need for HPUPagedAttentionMetadata preparation
@@ -105,7 +105,7 @@ def prepare_for_decode(
         padding = target_len - input_len
         return input + [v] * padding
-    last_block_usage = slot % BLOCK_SIZE + 1
+    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
     block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
     block_usage = [
         [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
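
The two hunks above change `prepare_for_decode` to take the whole `slots` list and compute `last_block_usage` with a Python list comprehension instead of a tensor expression, keeping this small piece of bookkeeping on the host. A minimal sketch of that idea, with an assumed `BLOCK_SIZE` value and illustrative names (not the module's actual variables):

BLOCK_SIZE = 128  # assumed block size, for illustration only

def last_block_usage_on_host(slots: list[int]) -> list[int]:
    # Plain Python arithmetic: no device kernel launch and no
    # host/device synchronization for a handful of integers.
    return [slot % BLOCK_SIZE + 1 for slot in slots]

print(last_block_usage_on_host([0, 130, 255]))  # -> [1, 3, 128]
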
@@ -964,7 +964,7 @@ class FlashCausalLMBatch(Batch):
         )
     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
-        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+        block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
@@ -984,7 +984,7 @@ class FlashCausalLMBatch(Batch):
             dtype,
             use_contiguous_pa,
             self.block_tables_tensor.device,
-            slots,
+            slots.cpu(),
             block_tables,
             padded_bs,
             bucketing_ctx,
@@ -1616,7 +1616,6 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -1641,13 +1640,14 @@ class FlashCausalLM(Model):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
             adapter_data=None,
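
The two hunks above defer the `slots` tensor in the warmup path: the Python list feeds the host-side metadata preparation, and a device tensor (`slots_tensor`) is only materialized right before `self.model.forward`. A rough sketch of that ordering, assuming a torch device and using illustrative names:

import torch

def build_decode_inputs(slots, device, block_size=128):
    # Host-side bookkeeping is derived from the plain Python list ...
    last_block_usage = [s % block_size + 1 for s in slots]
    # ... and the device tensor is created once, right before the forward call
    # that actually needs it (model.forward(slots=slots_tensor, ...)).
    slots_tensor = torch.tensor(slots, dtype=torch.long, device=device)
    return slots_tensor, last_block_usage

The point is the ordering: list-based metadata is computed before any tensor exists, so the code never converts `slots` to a device tensor only to read individual elements back.
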
@@ -1866,8 +1866,8 @@ class FlashCausalLM(Model):
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths[i]
+                    + batch.input_lengths[i] : batch.cache_lengths[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
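
The hunk above builds the slice bounds from the plain `batch.cache_lengths` list instead of `batch.cache_lengths_tensor`. On an accelerator, every scalar read from a device tensor is a blocking device-to-host transfer, while indexing a host list costs nothing. A small illustration (the tensor is on the CPU here only because the sketch has no HPU; names are illustrative):

import torch

cache_lengths = [7, 12, 3]                          # host-side list
cache_lengths_tensor = torch.tensor(cache_lengths)  # lives on the device in the real code

# Tensor indexing: each element read materializes a scalar from the device.
start_from_tensor = int(cache_lengths_tensor[1])

# List indexing: same value, no device round trip.
start_from_list = cache_lengths[1]

assert start_from_tensor == start_from_list == 12
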
@@ -1915,14 +1915,36 @@ class FlashCausalLM(Model):
                 }
             )
             idx = len(prev_batches) - 1
+            if batch.speculative_logits is not None:
+                accepted_ids_cpu = accepted_ids.cpu()
             for req_idx, req in enumerate(batch.requests):
+                new_input_length = 1
+                if batch.speculative_logits is not None:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx]
+                        + batch.input_lengths[req_idx]
+                        + accepted_ids_cpu[req_idx]
+                        - 1
+                    )
+                else:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
+                    )
+                batch.cache_lengths[req_idx] = new_cache_length
+                batch.max_input_length = max(
+                    batch.max_input_length, new_input_length
+                )
+                batch.input_lengths[req_idx] = new_input_length
+                current_length = new_cache_length + new_input_length
+                batch.max_current_length = max(
+                    batch.max_current_length, current_length
+                )
                 requests_to_generate.append(
                     {
                         "idx": idx,
                         "request_id": req.id,
-                        "cache_length": batch.cache_lengths[req_idx],
-                        "input_length": batch.input_lengths[req_idx],
                         "prefix_offset": batch.prefix_offsets[req_idx],
                         "read_offset": batch.read_offsets[req_idx],
                         "stopping_criteria": batch.stopping_criterias[req_idx],
@@ -2029,8 +2051,6 @@ class FlashCausalLM(Model):
         for i, req_data in enumerate(requests_to_generate):
             idx = req_data["idx"]
             request_id = req_data["request_id"]
-            cache_length = req_data["cache_length"]
-            input_length = req_data["input_length"]
             prefix_offset = req_data["prefix_offset"]
             read_offset = req_data["read_offset"]
             stopping_criteria = req_data["stopping_criteria"]
@@ -2041,9 +2061,6 @@ class FlashCausalLM(Model):
             n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
             top_token_ids = req_data["top_token_ids"]
             top_token_logprobs = req_data["top_token_logprobs"]
-            new_input_length = 1
-            new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2159,11 +2176,6 @@ class FlashCausalLM(Model):
             # Update values
             indexs[idx] += n_accepted_ids
             idx_accept_ids[idx] += 1
-            batch.cache_lengths[i] = new_cache_length
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
-            batch.input_lengths[i] = new_input_length
-            current_length = new_cache_length + new_input_length
-            batch.max_current_length = max(batch.max_current_length, current_length)
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset


@@ -408,7 +408,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -433,13 +432,14 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,


@@ -247,7 +247,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -279,12 +278,13 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         indices, cross_attention_len = generate_cross_attention_states(
             cross_attention_states, image_indices, seqlen, 1, False
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,