Mirror of https://github.com/huggingface/text-generation-inference.git
commit 151d6638d3 (parent 249189d96e)

    avoid reshape of all_input_ids_tensor

    Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
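The change below keeps all_input_ids_tensor at a fixed (padded_batch, max_total_tokens) shape and updates rows in place, so the tensor is no longer reshaped or re-padded on every prefill, concatenate, or filter step. A minimal sketch of that pattern, with illustrative names rather than the actual TGI code:

# Minimal sketch (not the TGI code): allocate the token buffer once at its
# maximum shape and mutate it in place with slicing / index_copy_.
import torch

max_padded_bs, max_total_tokens = 8, 16
all_input_ids_tensor = torch.zeros(
    (max_padded_bs, max_total_tokens), dtype=torch.int64
)

# prefill: write each prompt into the front of its row
prompt = torch.tensor([101, 7592, 2088], dtype=torch.int64)
all_input_ids_tensor[0, : prompt.numel()] = prompt

# concatenate: copy another batch's rows into slots start_index:end_index in place
other = torch.randint(0, 1000, (2, max_total_tokens), dtype=torch.int64)
start_index, end_index = 1, 3
index = torch.arange(start_index, end_index)
all_input_ids_tensor.index_copy_(0, index, other)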
@@ -428,10 +428,8 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
+        # put on cpu temporarily, move to hpu in prepare_for_prefill
+        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
 
         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
 
@@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
-            (total_batch_size, max_length)
-        )
+        all_input_ids_tensor = batches[0].all_input_ids_tensor
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
         )
@@ -829,9 +825,10 @@ class FlashCausalLMBatch(Batch):
 
             index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
-            all_input_ids_tensor[
-                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
+            if i > 0:
+                all_input_ids_tensor.index_copy_(
+                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                )
 
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
@@ -987,7 +984,6 @@ class FlashCausalLMBatch(Batch):
         else:
             padded_bs = self.input_ids.shape[0]
             slots = self.slots[self.slot_indices]
-        extra_pad = padded_bs - self.input_ids.shape[0]
 
         self.hpu_attn_meta = prepare_for_decode(
             dtype,
@@ -998,17 +994,20 @@ class FlashCausalLMBatch(Batch):
             padded_bs,
             bucketing_ctx,
         )
-        self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
-        self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
+        self.input_ids = F.pad(
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+        )
+        self.position_ids = F.pad(
+            self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
+        )
         self.input_lengths_tensor = F.pad(
-            self.input_lengths_tensor, (0, extra_pad), value=0
+            self.input_lengths_tensor,
+            (0, padded_bs - self.input_lengths_tensor.shape[0]),
+            value=0,
         )
         self.cache_lengths_tensor = F.pad(
-            self.cache_lengths_tensor, (0, extra_pad), value=0
-        )
-        self.all_input_ids_tensor = F.pad(
-            self.all_input_ids_tensor,
-            (0, 0, 0, extra_pad),
+            self.cache_lengths_tensor,
+            (0, padded_bs - self.cache_lengths_tensor.shape[0]),
             value=0,
         )
         next_token_chooser_parameters = []
@@ -1028,7 +1027,9 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states,
         )
 
-    def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
         # it simplifies everything
@@ -1044,7 +1045,7 @@ class FlashCausalLMBatch(Batch):
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
         extra_pad_bs = max_padded_bs - len(self)
-        device = self.all_input_ids_tensor.device
+        device = "hpu"
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1288,12 +1289,15 @@ class FlashCausalLMBatch(Batch):
             self.prefill_next_token_indices = (
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
-        self.all_input_ids_tensor = F.pad(
-            self.all_input_ids_tensor,
-            (0, 0, 0, extra_pad_bs),
-            value=0,
+        all_input_ids_tensor = torch.zeros(
+            (max_padded_bs, max_total_tokens), dtype=torch.int64, device="hpu"
         )
+        for i in range(len(self)):
+            all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
+                self.all_input_ids_tensor[i]
+            )
+        self.all_input_ids_tensor = all_input_ids_tensor
 
         next_token_chooser_parameters = []
         next_token_chooser_parameters.extend([r.parameters for r in self.requests])
         pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
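For context on the hunk above: instead of F.pad-ing the existing tensor by extra_pad_bs rows, prepare_for_prefill now allocates the buffer once at the warmup shape (max_padded_bs, max_total_tokens) on HPU and copies each request's row into it. A rough CPU-runnable sketch of that step (hypothetical shapes; the real code uses device="hpu"):

# Sketch only: stage per-request ids on CPU, then copy rows into a buffer that
# is already sized for the maximum batch and maximum total tokens.
import torch

device = "cpu"  # the Gaudi backend uses "hpu" here
staged = torch.randint(0, 1000, (3, 10), dtype=torch.int64)  # built earlier on CPU

max_padded_bs, max_total_tokens = 4, 32
all_input_ids_tensor = torch.zeros(
    (max_padded_bs, max_total_tokens), dtype=torch.int64, device=device
)
for i in range(staged.shape[0]):
    all_input_ids_tensor[i, : staged.shape[-1]] = staged[i]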
@@ -1459,6 +1463,8 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
         self.bucketing_ctx = None
+        self.max_total_tokens = None
+        self.max_input_tokens = None
         htorch.core.hpu_set_env()
         if htorch.utils.internal.is_lazy():
             htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -1564,6 +1570,14 @@ class FlashCausalLM(Model):
             logger.info,
             f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
         )
+        if max_total_tokens is None:
+            max_total_tokens = sum(batch.input_lengths)
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+
+        self.max_total_tokens = max_total_tokens
+        self.max_input_tokens = max_input_tokens
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1597,11 +1611,6 @@ class FlashCausalLM(Model):
         )
 
         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
-        if max_total_tokens is None:
-            max_total_tokens = sum(batch.input_lengths)
-
-        if max_input_tokens is None:
-            max_input_tokens = max_total_tokens - 1
 
         self.kv_cache = []
         empty_cache()
@@ -2017,7 +2026,9 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_current_length],
+            batch.all_input_ids_tensor[
+                : batch.next_token_logits.shape[0], : batch.max_current_length
+            ],
             batch.next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -2033,9 +2044,14 @@ class FlashCausalLM(Model):
         if batch.valid_indices is not None:
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
-            batch.all_input_ids_tensor = batch.all_input_ids_tensor[
-                batch.valid_indices
-            ]
+            index = torch.arange(
+                0,
+                len(batch.valid_indices),
+                device=batch.all_input_ids_tensor.device,
+            )
+            batch.all_input_ids_tensor.index_copy_(
+                0, index, batch.all_input_ids_tensor[batch.valid_indices]
+            )
             next_input_ids = next_input_ids[batch.valid_indices]
             next_token_logprobs = next_token_logprobs[batch.valid_indices]
             accepted_ids = accepted_ids[batch.valid_indices]
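The filtering path above follows the same idea: rather than replacing batch.all_input_ids_tensor with the smaller batch.all_input_ids_tensor[batch.valid_indices], the surviving rows are compacted into the front of the existing buffer so its shape never changes. A small illustrative sketch:

# Sketch only: keep the rows listed in valid_indices by copying them to the front
# of the same tensor; trailing rows become stale but the shape stays fixed.
import torch

all_input_ids_tensor = torch.arange(20, dtype=torch.int64).reshape(4, 5)
valid_indices = torch.tensor([1, 3])

index = torch.arange(0, len(valid_indices), device=all_input_ids_tensor.device)
all_input_ids_tensor.index_copy_(0, index, all_input_ids_tensor[valid_indices])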
@@ -2208,9 +2224,12 @@ class FlashCausalLM(Model):
                         batch.max_input_length
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
+                    self.max_total_tokens,
                 )
             else:
-                batch.prepare_for_prefill(batch.max_input_length, len(batch))
+                batch.prepare_for_prefill(
+                    batch.max_input_length, len(batch), self.max_total_tokens
+                )
         else:
             batch.prepare_for_decode(
                 self.dtype, self.use_contiguous_pa, self.bucketing_ctx