Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 22:02:06 +00:00

improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 76cc129796
commit ba049c9d49
@@ -89,7 +89,7 @@ def get_sliding_windows() -> int:


 def prepare_for_decode(
-    dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
+    dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
 ):
     # Prepare values if we need to continue decoding
     # need for HPUPagedAttentionMetadata preparation
@@ -105,7 +105,7 @@ def prepare_for_decode(
         padding = target_len - input_len
         return input + [v] * padding

-    last_block_usage = slot % BLOCK_SIZE + 1
+    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
     block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
     block_usage = [
         [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
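The hunks above switch `prepare_for_decode` from a single `slot` scalar to the full `slots` list, so the paged-attention metadata can be assembled with plain list comprehensions on the host. A minimal illustrative sketch (not part of the diff) of that pattern follows; `BLOCK_SIZE`, `make_decode_metadata`, and the example inputs are assumptions for the sketch, not identifiers taken from the repository.

import torch

BLOCK_SIZE = 128  # assumed block size for the sketch


def make_decode_metadata(slots, block_tables, device):
    # Everything below is host-side list work; device tensors are created once at the end.
    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
    block_usage = [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
        for bt, lbu in zip(block_tables, last_block_usage)
    ]

    # Flatten the nested lists and materialize each tensor in one shot.
    flat_blocks = [b for bt in block_tables for b in bt]
    flat_groups = [g for bg in block_groups for g in bg]
    flat_usage = [u for bu in block_usage for u in bu]
    return (
        torch.tensor(flat_blocks, dtype=torch.int32, device=device),
        torch.tensor(flat_groups, dtype=torch.int32, device=device),
        torch.tensor(flat_usage, dtype=torch.int32, device=device),
    )


print(make_decode_metadata(slots=[5, 200], block_tables=[[0], [1, 2]], device="cpu"))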
@@ -964,7 +964,7 @@ class FlashCausalLMBatch(Batch):
         )

     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
-        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+        block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
@@ -984,7 +984,7 @@ class FlashCausalLMBatch(Batch):
             dtype,
             use_contiguous_pa,
             self.block_tables_tensor.device,
-            slots,
+            slots.cpu(),
             block_tables,
             padded_bs,
             bucketing_ctx,
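These two hunks compute `block_num` from the host-side `self.cache_lengths` list instead of the device tensor, and hand `slots.cpu()` to `prepare_for_decode` so later element accesses stay on the host. A minimal sketch of the idea (not part of the diff); `BLOCK_SIZE`, `trim_block_tables`, and the inputs are assumed for illustration.

import torch

BLOCK_SIZE = 128  # assumed block size for the sketch


def trim_block_tables(cache_lengths, block_tables):
    # Integer math on host-side lists; no per-sequence read from a device tensor.
    block_num = [length // BLOCK_SIZE + 1 for length in cache_lengths]
    return [bt[: block_num[i]] for i, bt in enumerate(block_tables)]


slots = torch.tensor([5, 300])
# Copying slots to the host once makes later indexing/iteration cheap,
# instead of forcing a blocking device read for every element.
slots_cpu = slots.cpu()
print(trim_block_tables(cache_lengths=[100, 260], block_tables=[[0, 7, 9], [1, 2, 3]]))
print(slots_cpu.tolist())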
@@ -1616,7 +1616,6 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -1641,13 +1640,14 @@ class FlashCausalLM(Model):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
             adapter_data=None,
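In the warmup path the diff keeps `slots` as a Python list and only builds the device tensor (`slots_tensor`) right before the forward call. A minimal sketch of that ordering (not part of the diff); `fake_forward`, the slot layout, and the parameters are assumptions for the example.

import torch


def fake_forward(slots: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward pass that consumes the slots tensor.
    return slots.sum()


def warmup_decode(batch_size: int, blocks_per_seq: int, block_size: int, device="cpu"):
    slots = []
    for i in range(batch_size):
        # Assumed layout: each sequence ends on the last slot of its last block.
        slots.append((i + 1) * blocks_per_seq * block_size - 1)
    # Single host-to-device copy right before the forward call, instead of
    # building the tensor early and carrying it through the setup code.
    slots_tensor = torch.tensor(slots, dtype=torch.int64, device=device)
    return fake_forward(slots_tensor)


print(warmup_decode(batch_size=4, blocks_per_seq=2, block_size=128))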
@@ -1866,8 +1866,8 @@ class FlashCausalLM(Model):
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths[i]
+                    + batch.input_lengths[i] : batch.cache_lengths[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
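Here the slice bounds now come from `batch.cache_lengths` (a host-side list) rather than `cache_lengths_tensor`. A small sketch of the difference (not part of the diff); all names and shapes below are assumptions for illustration.

import torch

all_input_ids = torch.zeros(2, 16, dtype=torch.int64)
next_input_ids = torch.tensor([7, 8, 9])
cu_accepted_ids = [0, 2, 3]   # host-side prefix sums of accepted tokens
cache_lengths = [4, 6]        # plain Python ints
input_lengths = [1, 1]
accepted_ids = [2, 1]

for i in range(2):
    start = cache_lengths[i] + input_lengths[i]
    # start and the slice bounds are host ints, so building this slice does not
    # trigger a device-to-host read the way indexing a device-resident
    # cache_lengths_tensor[i] would.
    all_input_ids[i, start : start + accepted_ids[i]] = next_input_ids[
        cu_accepted_ids[i] : cu_accepted_ids[i + 1]
    ]

print(all_input_ids)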
@@ -1915,14 +1915,36 @@ class FlashCausalLM(Model):
                 }
             )
+            idx = len(prev_batches) - 1
+            if batch.speculative_logits is not None:
+                accepted_ids_cpu = accepted_ids.cpu()

             for req_idx, req in enumerate(batch.requests):
+                new_input_length = 1
+                if batch.speculative_logits is not None:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx]
+                        + batch.input_lengths[req_idx]
+                        + accepted_ids_cpu[req_idx]
+                        - 1
+                    )
+                else:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
+                    )
+                batch.cache_lengths[req_idx] = new_cache_length
+                batch.max_input_length = max(
+                    batch.max_input_length, new_input_length
+                )
+                batch.input_lengths[req_idx] = new_input_length
+                current_length = new_cache_length + new_input_length
+                batch.max_current_length = max(
+                    batch.max_current_length, current_length
+                )

                 requests_to_generate.append(
                     {
                         "idx": idx,
                         "request_id": req.id,
-                        "cache_length": batch.cache_lengths[req_idx],
-                        "input_length": batch.input_lengths[req_idx],
                         "prefix_offset": batch.prefix_offsets[req_idx],
                         "read_offset": batch.read_offsets[req_idx],
                         "stopping_criteria": batch.stopping_criterias[req_idx],
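This hunk moves the per-request length bookkeeping into the batch loop and copies `accepted_ids` to the host once with a single `.cpu()` call, so the updates run on plain Python ints. A minimal, self-contained sketch of the same pattern (not part of the diff); `ToyBatch` and `update_lengths` are stand-ins, not the repository's `FlashCausalLMBatch` API.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyBatch:
    # Stand-in for the batch bookkeeping; not the repository's FlashCausalLMBatch.
    cache_lengths: List[int]
    input_lengths: List[int]
    max_input_length: int = 0
    max_current_length: int = 0


def update_lengths(batch: ToyBatch, accepted_ids: torch.Tensor, speculative: bool):
    # One device-to-host transfer for the whole batch, then pure-Python updates.
    accepted_ids_cpu = accepted_ids.cpu() if speculative else None
    for req_idx in range(len(batch.cache_lengths)):
        new_input_length = 1
        if speculative:
            new_cache_length = (
                batch.cache_lengths[req_idx]
                + batch.input_lengths[req_idx]
                + int(accepted_ids_cpu[req_idx])
                - 1
            )
        else:
            new_cache_length = (
                batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
            )
        batch.cache_lengths[req_idx] = new_cache_length
        batch.input_lengths[req_idx] = new_input_length
        batch.max_input_length = max(batch.max_input_length, new_input_length)
        batch.max_current_length = max(
            batch.max_current_length, new_cache_length + new_input_length
        )


b = ToyBatch(cache_lengths=[10, 4], input_lengths=[3, 1])
update_lengths(b, accepted_ids=torch.tensor([2, 1]), speculative=True)
print(b)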
@@ -2029,8 +2051,6 @@ class FlashCausalLM(Model):
         for i, req_data in enumerate(requests_to_generate):
             idx = req_data["idx"]
             request_id = req_data["request_id"]
-            cache_length = req_data["cache_length"]
-            input_length = req_data["input_length"]
             prefix_offset = req_data["prefix_offset"]
             read_offset = req_data["read_offset"]
             stopping_criteria = req_data["stopping_criteria"]
@@ -2041,9 +2061,6 @@ class FlashCausalLM(Model):
             n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
             top_token_ids = req_data["top_token_ids"]
             top_token_logprobs = req_data["top_token_logprobs"]
-
-            new_input_length = 1
-            new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2159,11 +2176,6 @@ class FlashCausalLM(Model):
             # Update values
             indexs[idx] += n_accepted_ids
             idx_accept_ids[idx] += 1
-            batch.cache_lengths[i] = new_cache_length
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
-            batch.input_lengths[i] = new_input_length
-            current_length = new_cache_length + new_input_length
-            batch.max_current_length = max(batch.max_current_length, current_length)

             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
@@ -408,7 +408,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -433,13 +432,14 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
@@ -247,7 +247,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -279,12 +278,13 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         indices, cross_attention_len = generate_cross_attention_states(
             cross_attention_states, image_indices, seqlen, 1, False
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,