Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Commit ba049c9d49 (parent 76cc129796): improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

The hunks below keep decode-time metadata preparation on the host: `prepare_for_decode` now takes a plain Python list of slots, block counts are derived from `cache_lengths` instead of `cache_lengths_tensor`, the `slots` tensor is only built right before `model.forward`, and the per-request length bookkeeping is folded into a single loop that uses one `accepted_ids.cpu()` transfer.
@@ -89,7 +89,7 @@ def get_sliding_windows() -> int:
 
 
 def prepare_for_decode(
-    dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
+    dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
 ):
     # Prepare values if we need to continue decoding
     # need for HPUPagedAttentionMetadata preparation
@@ -105,7 +105,7 @@ def prepare_for_decode(
         padding = target_len - input_len
         return input + [v] * padding
 
-    last_block_usage = slot % BLOCK_SIZE + 1
+    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
    block_usage = [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
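
A small self-contained sketch (not the repository code; `BLOCK_SIZE` and the sample inputs are assumptions) of the list-based computation that the updated `prepare_for_decode` signature implies, i.e. deriving per-sequence block usage from a plain Python list of slots instead of a single scalar:

BLOCK_SIZE = 16  # assumed page/block size, for illustration only


def decode_block_metadata(slots, block_tables):
    # Host-side usage of the last block of each sequence, one value per sequence.
    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
    # Map every block back to the sequence it belongs to.
    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
    # Full blocks are completely used; only the last block is partially used.
    block_usage = [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
        for bt, lbu in zip(block_tables, last_block_usage)
    ]
    return last_block_usage, block_groups, block_usage


# Two sequences: one spanning two KV-cache blocks, one spanning a single block.
print(decode_block_metadata([17, 3], [[0, 1], [2]]))
# ([2, 4], [[0, 0], [1]], [[16, 2], [4]])
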
@@ -964,7 +964,7 @@ class FlashCausalLMBatch(Batch):
         )
 
     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
-        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+        block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
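
A rough sketch of the batch-side change, with invented lengths and block tables: the number of in-use blocks is now computed from the host-side `cache_lengths` list rather than from `cache_lengths_tensor`, so trimming the block tables needs no device read-back.

BLOCK_SIZE = 16  # assumed block size, for illustration only

cache_lengths = [5, 40]                       # plain Python ints, one per sequence
full_block_tables = [[0, 1, 2], [3, 4, 5, 6]]

# Blocks actually in use per sequence, computed without touching a tensor.
block_num = [length // BLOCK_SIZE + 1 for length in cache_lengths]

block_tables = [bt[0 : block_num[i]] for i, bt in enumerate(full_block_tables)]
print(block_num, block_tables)                # [1, 3] [[0], [3, 4, 5]]
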
@@ -984,7 +984,7 @@ class FlashCausalLMBatch(Batch):
             dtype,
             use_contiguous_pa,
             self.block_tables_tensor.device,
-            slots,
+            slots.cpu(),
             block_tables,
             padded_bs,
             bucketing_ctx,
@@ -1616,7 +1616,6 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -1641,13 +1640,14 @@ class FlashCausalLM(Model):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
             adapter_data=None,
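
The dummy decode run used for warmup follows the same pattern: the paged-attention metadata is prepared from the plain `slots` list, and a device tensor is only materialized right before the forward call. A schematic sketch of that ordering (the dtype choice here is an assumption):

import torch

slots = [0, 16, 32]  # host-side slot indices, one per dummy sequence

# ... build HPUPagedAttentionMetadata and the other inputs from the Python list ...

# Only the model call itself needs a device tensor, so it is created last.
slots_tensor = torch.tensor(slots, dtype=torch.int64)
# model.forward(..., slots=slots_tensor, ...)
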
@@ -1866,8 +1866,8 @@ class FlashCausalLM(Model):
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths[i]
+                    + batch.input_lengths[i] : batch.cache_lengths[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
@@ -1915,14 +1915,36 @@ class FlashCausalLM(Model):
                 }
             )
             idx = len(prev_batches) - 1
+            if batch.speculative_logits is not None:
+                accepted_ids_cpu = accepted_ids.cpu()
 
             for req_idx, req in enumerate(batch.requests):
+                new_input_length = 1
+                if batch.speculative_logits is not None:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx]
+                        + batch.input_lengths[req_idx]
+                        + accepted_ids_cpu[req_idx]
+                        - 1
+                    )
+                else:
+                    new_cache_length = (
+                        batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
+                    )
+                batch.cache_lengths[req_idx] = new_cache_length
+                batch.max_input_length = max(
+                    batch.max_input_length, new_input_length
+                )
+                batch.input_lengths[req_idx] = new_input_length
+                current_length = new_cache_length + new_input_length
+                batch.max_current_length = max(
+                    batch.max_current_length, current_length
+                )
+
                 requests_to_generate.append(
                     {
                         "idx": idx,
                         "request_id": req.id,
-                        "cache_length": batch.cache_lengths[req_idx],
-                        "input_length": batch.input_lengths[req_idx],
                         "prefix_offset": batch.prefix_offsets[req_idx],
                         "read_offset": batch.read_offsets[req_idx],
                         "stopping_criteria": batch.stopping_criterias[req_idx],
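
The hunk above folds the per-request length bookkeeping into the loop that builds `requests_to_generate`, so speculative decoding needs only one `accepted_ids.cpu()` transfer for the whole batch instead of per-request device reads later on. A simplified, self-contained sketch of that host-side bookkeeping (the inputs and the `speculative` flag are invented for illustration):

import torch

# Invented stand-ins for the batch fields this hunk touches.
cache_lengths = [7, 12]
input_lengths = [1, 1]
max_input_length = 1
max_current_length = 13
speculative = True
accepted_ids = torch.tensor([2, 1])  # lives on the device in the real code

# One device-to-host copy for the whole batch.
accepted_ids_cpu = accepted_ids.cpu() if speculative else None

for req_idx in range(len(cache_lengths)):
    new_input_length = 1
    if speculative:
        new_cache_length = (
            cache_lengths[req_idx]
            + input_lengths[req_idx]
            + int(accepted_ids_cpu[req_idx])
            - 1
        )
    else:
        new_cache_length = cache_lengths[req_idx] + input_lengths[req_idx]
    cache_lengths[req_idx] = new_cache_length
    input_lengths[req_idx] = new_input_length
    max_input_length = max(max_input_length, new_input_length)
    max_current_length = max(max_current_length, new_cache_length + new_input_length)

print(cache_lengths, input_lengths, max_current_length)  # [9, 13] [1, 1] 14

Because the lengths are now updated here, the later per-request loop no longer needs `cache_length` and `input_length` from `req_data`, which is why the following hunks drop those dictionary entries, locals, and the duplicate updates.
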
@@ -2029,8 +2051,6 @@ class FlashCausalLM(Model):
         for i, req_data in enumerate(requests_to_generate):
             idx = req_data["idx"]
             request_id = req_data["request_id"]
-            cache_length = req_data["cache_length"]
-            input_length = req_data["input_length"]
             prefix_offset = req_data["prefix_offset"]
             read_offset = req_data["read_offset"]
             stopping_criteria = req_data["stopping_criteria"]
@@ -2041,9 +2061,6 @@ class FlashCausalLM(Model):
             n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
             top_token_ids = req_data["top_token_ids"]
             top_token_logprobs = req_data["top_token_logprobs"]
-
-            new_input_length = 1
-            new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2159,11 +2176,6 @@ class FlashCausalLM(Model):
             # Update values
             indexs[idx] += n_accepted_ids
             idx_accept_ids[idx] += 1
-            batch.cache_lengths[i] = new_cache_length
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
-            batch.input_lengths[i] = new_input_length
-            current_length = new_cache_length + new_input_length
-            batch.max_current_length = max(batch.max_current_length, current_length)
 
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
@@ -408,7 +408,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -433,13 +432,14 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
@@ -247,7 +247,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -279,12 +278,13 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         indices, cross_attention_len = generate_cross_attention_states(
             cross_attention_states, image_indices, seqlen, 1, False
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,