From 151d6638d332dd37db46e3951cf8a7eb955b8014 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Mon, 2 Jun 2025 22:17:31 -0700
Subject: [PATCH] avoid reshape of all_input_ids_tensor

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py | 91 +++++++++++--------
 1 file changed, 55 insertions(+), 36 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index f8abe5ad..e9a320d9 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -428,10 +428,8 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
+        # put on cpu temporarily, move to hpu in prepare_for_prefill
+        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
 
         top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
 
@@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
-            (total_batch_size, max_length)
-        )
+        all_input_ids_tensor = batches[0].all_input_ids_tensor
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
         )
@@ -829,9 +825,10 @@ class FlashCausalLMBatch(Batch):
 
             index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
-            all_input_ids_tensor[
-                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
+            if i > 0:
+                all_input_ids_tensor.index_copy_(
+                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                )
 
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
@@ -987,7 +984,6 @@ class FlashCausalLMBatch(Batch):
         else:
             padded_bs = self.input_ids.shape[0]
         slots = self.slots[self.slot_indices]
-        extra_pad = padded_bs - self.input_ids.shape[0]
 
         self.hpu_attn_meta = prepare_for_decode(
             dtype,
@@ -998,17 +994,20 @@ class FlashCausalLMBatch(Batch):
             padded_bs,
             bucketing_ctx,
         )
-        self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
-        self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
+        self.input_ids = F.pad(
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+        )
+        self.position_ids = F.pad(
+            self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
+        )
         self.input_lengths_tensor = F.pad(
-            self.input_lengths_tensor, (0, extra_pad), value=0
+            self.input_lengths_tensor,
+            (0, padded_bs - self.input_lengths_tensor.shape[0]),
+            value=0,
         )
         self.cache_lengths_tensor = F.pad(
-            self.cache_lengths_tensor, (0, extra_pad), value=0
-        )
-        self.all_input_ids_tensor = F.pad(
-            self.all_input_ids_tensor,
-            (0, 0, 0, extra_pad),
+            self.cache_lengths_tensor,
+            (0, padded_bs - self.cache_lengths_tensor.shape[0]),
             value=0,
         )
         next_token_chooser_parameters = []
@@ -1028,7 +1027,9 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states,
         )
 
-    def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
         # it simplifies everything
@@ -1044,7 +1045,7 @@ class FlashCausalLMBatch(Batch):
         # need extra pad to match warmup seq
        extra_pad = max_padded_input_len - self.max_input_length
        extra_pad_bs = max_padded_bs - len(self)
-        device = self.all_input_ids_tensor.device
+        device = "hpu"
        if isinstance(self.input_ids, list) and len(self) > 1:
            input_ids_padded_length = []
            input_ids = []
@@ -1288,12 +1289,15 @@ class FlashCausalLMBatch(Batch):
             self.prefill_next_token_indices = (
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
-
-        self.all_input_ids_tensor = F.pad(
-            self.all_input_ids_tensor,
-            (0, 0, 0, extra_pad_bs),
-            value=0,
+        all_input_ids_tensor = torch.zeros(
+            (max_padded_bs, max_total_tokens), dtype=torch.int64, device="hpu"
         )
+        for i in range(len(self)):
+            all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
+                self.all_input_ids_tensor[i]
+            )
+        self.all_input_ids_tensor = all_input_ids_tensor
+
         next_token_chooser_parameters = []
         next_token_chooser_parameters.extend([r.parameters for r in self.requests])
         pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
@@ -1459,6 +1463,8 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
         self.bucketing_ctx = None
+        self.max_total_tokens = None
+        self.max_input_tokens = None
         htorch.core.hpu_set_env()
         if htorch.utils.internal.is_lazy():
             htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -1564,6 +1570,14 @@ class FlashCausalLM(Model):
             logger.info,
             f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
         )
+        if max_total_tokens is None:
+            max_total_tokens = sum(batch.input_lengths)
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+
+        self.max_total_tokens = max_total_tokens
+        self.max_input_tokens = max_input_tokens
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1597,11 +1611,6 @@ class FlashCausalLM(Model):
         )
         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
 
-        if max_total_tokens is None:
-            max_total_tokens = sum(batch.input_lengths)
-
-        if max_input_tokens is None:
-            max_input_tokens = max_total_tokens - 1
 
         self.kv_cache = []
         empty_cache()
@@ -2017,7 +2026,9 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_current_length],
+            batch.all_input_ids_tensor[
+                : batch.next_token_logits.shape[0], : batch.max_current_length
+            ],
             batch.next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -2033,9 +2044,14 @@ class FlashCausalLM(Model):
         if batch.valid_indices is not None:
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
-            batch.all_input_ids_tensor = batch.all_input_ids_tensor[
-                batch.valid_indices
-            ]
+            index = torch.arange(
+                0,
+                len(batch.valid_indices),
+                device=batch.all_input_ids_tensor.device,
+            )
+            batch.all_input_ids_tensor.index_copy_(
+                0, index, batch.all_input_ids_tensor[batch.valid_indices]
+            )
             next_input_ids = next_input_ids[batch.valid_indices]
             next_token_logprobs = next_token_logprobs[batch.valid_indices]
             accepted_ids = accepted_ids[batch.valid_indices]
@@ -2208,9 +2224,12 @@ class FlashCausalLM(Model):
                         batch.max_input_length
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
+                    self.max_total_tokens,
                 )
             else:
-                batch.prepare_for_prefill(batch.max_input_length, len(batch))
+                batch.prepare_for_prefill(
+                    batch.max_input_length, len(batch), self.max_total_tokens
+                )
         else:
             batch.prepare_for_decode(
                 self.dtype, self.use_contiguous_pa, self.bucketing_ctx
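Note (not part of the patch): the change replaces per-step padding/reshaping of all_input_ids_tensor with a single buffer allocated once at (max_padded_bs, max_total_tokens) in prepare_for_prefill and then updated in place with index_copy_. The snippet below is a minimal CPU-only sketch of that buffer-reuse pattern; the sizes, prompts, and valid_indices values are made up for illustration and are not taken from text_generation_server.

# Illustrative sketch only (CPU). In the patch the buffer lives on HPU and the
# bounds come from warmup (self.max_total_tokens and the padded batch size).
import torch

max_padded_bs, max_total_tokens = 4, 16  # hypothetical warmup-derived bounds

# Allocate the token buffer once at its maximum size.
all_input_ids = torch.zeros((max_padded_bs, max_total_tokens), dtype=torch.int64)

# Prefill: copy each request's prompt into its row; shorter rows stay zero-padded.
prompts = [torch.tensor([101, 7592, 2088]), torch.tensor([101, 2129])]
for i, prompt in enumerate(prompts):
    all_input_ids[i, : prompt.shape[0]] = prompt

# Filtering finished requests: compact the surviving rows to the front with
# index_copy_ instead of materializing a resized copy of the buffer.
valid_indices = torch.tensor([1])
index = torch.arange(0, len(valid_indices))
all_input_ids.index_copy_(0, index, all_input_ids[valid_indices])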