From 9e7e546923abf89b11adcd0c34c98a8e123d12a8 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 22 May 2025 15:21:31 +0800
Subject: [PATCH] Move input_ids to hpu and remove disposal of adapter_meta (#3237)

Signed-off-by: Wang, Yi A
---
 .../layers/attention/common.py     |   2 +
 .../models/flash_causal_lm.py      | 227 ++++++++++--------
 .../models/flash_vlm_causal_lm.py  |   6 +-
 .../models/mllama_causal_lm.py     |   8 +-
 4 files changed, 136 insertions(+), 107 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py
index 9bd738fc..5e03cd44 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/common.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py
@@ -90,6 +90,8 @@ class Seqlen:
 def _async_h2d_tensor_copy(source, device="hpu"):
     if source is None:
         return None
+    if source.device.type == "hpu":
+        return source
     assert source.device.type == "cpu", "Source tensor is not present in host memory!"
     target = torch.empty(source.shape, dtype=source.dtype, device=device)
     target.copy_(source, non_blocking=True)
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index bc0d240e..f8abe5ad 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -634,21 +634,25 @@ class FlashCausalLMBatch(Batch):
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         cache_lengths_tensor = self.cache_lengths_tensor[indices]
 
         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
-
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if self.adapter_meta is not None:
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            adapter_segments, adapter_segment_indices = find_segments(
+                adapter_indices
+            )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
+        else:
+            adapter_meta = None
         htorch.core.mark_step()
         return type(self)(
             batch_id=self.batch_id,
@@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
         max_input_length = 0
         max_current_length = 0
+        ADAPTER_TO_INDEX = get_adapter_to_index()
         for b in batches:
             total_batch_size += len(b)
             max_blocks = max(max_blocks, b.max_blocks)
@@ -763,14 +768,15 @@ class FlashCausalLMBatch(Batch):
         cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
             total_batch_size
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_segment_builder = SegmentConcatBuilder()
-        adapter_set = set()
+        if ADAPTER_TO_INDEX:
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()
 
         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
@@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + valid_bsize
 
-            index = torch.tensor(
-                list(range(start_index, end_index)), device=batch.input_ids.device
-            )
+            index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -847,7 +851,9 @@
             )
 
             if not prefilling:
-                input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
+                input_ids.index_copy_(
+                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
@@ -858,20 +864,21 @@
                 cache_lengths_tensor.index_copy_(
                     0, index, batch.cache_lengths_tensor[:valid_bsize]
                 )
-                adapter_start_index = cumulative_adapter_indices_size
-                adapter_end_index = (
-                    cumulative_adapter_indices_size
-                    + batch.adapter_meta.adapter_indices.shape[0]
-                )
-                adapter_indices[adapter_start_index:adapter_end_index] = (
-                    batch.adapter_meta.adapter_indices
-                )
-                cumulative_adapter_indices_size = adapter_end_index
-                adapter_set.update(batch.adapter_meta.adapter_set)
-                adapter_segment_builder.concat(
-                    batch.adapter_meta.adapter_segments,
-                    batch.adapter_meta.segment_indices,
-                )
+                if ADAPTER_TO_INDEX:
+                    adapter_start_index = cumulative_adapter_indices_size
+                    adapter_end_index = (
+                        cumulative_adapter_indices_size
+                        + batch.adapter_meta.adapter_indices.shape[0]
+                    )
+                    adapter_indices[adapter_start_index:adapter_end_index] = (
+                        batch.adapter_meta.adapter_indices
+                    )
+                    cumulative_adapter_indices_size = adapter_end_index
+                    adapter_set.update(batch.adapter_meta.adapter_set)
+                    adapter_segment_builder.concat(
+                        batch.adapter_meta.adapter_segments,
+                        batch.adapter_meta.segment_indices,
+                    )
             else:
                 if isinstance(batch.input_ids, torch.Tensor):
                     batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -914,7 +921,7 @@
         else:
             speculative_ids = None
 
-        if adapter_segment_builder is not None:
+        if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
             adapter_meta = AdapterBatchMetadata(
                 adapter_indices=adapter_indices,
@@ -961,7 +968,7 @@
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=adapter_meta,
+            adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
@@ -1037,6 +1044,7 @@
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
         extra_pad_bs = max_padded_bs - len(self)
+        device = self.all_input_ids_tensor.device
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1047,12 +1055,12 @@
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
             input_ids = [0] * extra_pad + input_ids
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
             self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
             input_ids_padded_length.extend([extra_pad] * len(self))
@@ -1245,7 +1253,9 @@
         self.slot_indices = slot_indices
 
         self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+        self.prefill_cache_indices = torch.zeros_like(
+            self.input_ids, dtype=torch.bool, device="cpu"
+        )
         self.prefill_cache_indices[prefill_cache_indices] = True
 
         if all_prefill_logprobs:
@@ -1301,21 +1311,24 @@
             fsm_grammar_states,
         )
 
-        if adapter_set:
-            adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
-            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        else:
-            adapter_indices = torch.zeros_like(self.input_ids)
-            adapter_segments = [0, len(adapter_indices)]
-            adapter_segment_indices = [len(adapter_indices) - 1]
+        if ADAPTER_TO_INDEX:
+            if adapter_set:
+                adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
+                adapter_segments, adapter_segment_indices = find_segments(
+                    adapter_indices
+                )
+            else:
+                adapter_indices = torch.zeros_like(self.input_ids)
+                adapter_segments = [0, len(adapter_indices)]
+                adapter_segment_indices = [len(adapter_indices) - 1]
 
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        self.adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            self.adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
     def __len__(self):
         return len(self.requests)
@@ -1941,11 +1954,11 @@ class FlashCausalLM(Model):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
         seqlen = Seqlen(
@@ -1965,7 +1978,7 @@
         )
 
         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
@@ -2059,15 +2072,16 @@
                 batch.position_ids = batch.position_ids[indices]
 
                 batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
-                batch.adapter_meta.adapter_indices = (
-                    batch.adapter_meta.adapter_indices[indices]
-                )
+                if batch.adapter_meta is not None:
+                    batch.adapter_meta.adapter_indices = (
+                        batch.adapter_meta.adapter_indices[indices]
+                    )
 
             # For each member of the batch
             # Cumulative length
-            accepted_ids = accepted_ids.cpu()
-            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
-            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+            if batch.speculative_logits is not None:
+                cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+                torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
                 for i in range(len(batch)):
                     batch.all_input_ids_tensor[
                         i,
@@ -2076,6 +2090,20 @@
                         + batch.input_lengths[i]
                         + accepted_ids[i],
                     ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+                accepted_ids = accepted_ids.cpu()
+                if batch.position_ids.dim() == 2:
+                    # Qwen2_vl case:
+                    batch.position_ids += accepted_ids.unsqueeze(-1)
+                else:
+                    batch.position_ids += accepted_ids
+                batch.cache_lengths_tensor += (
+                    batch.input_lengths_tensor + accepted_ids - 1
+                )
+                batch.input_lengths_tensor = torch.ones_like(
+                    batch.input_lengths_tensor
+                )
+                batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
                 index = index.to(batch.all_input_ids_tensor.device)
@@ -2088,22 +2116,18 @@
                 batch.all_input_ids_tensor.index_put_(
                     (batch_idx, index.long()), next_input_ids
                 )
-            next_input_ids = next_input_ids.cpu()
-            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+                batch.input_ids = next_input_ids
+                batch.position_ids += 1
+                batch.cache_lengths_tensor += batch.input_lengths_tensor
+                batch.input_lengths_tensor = torch.ones_like(
+                    batch.input_lengths_tensor
+                )
+                batch.slot_indices += 1
+
             batch.speculative_ids = speculative_ids
-            if batch.position_ids.dim() == 2:
-                # Qwen2_vl case:
-                batch.position_ids += accepted_ids.unsqueeze(-1)
-            else:
-                batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += (
-                batch.input_lengths_tensor + accepted_ids - 1
-            )
-            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-            batch.slot_indices += accepted_ids[: len(batch)]
 
             # Does a HPU <-> CPU sync internally
-            if prefill:
+            if prefill and batch.adapter_meta is not None:
                 # adjust segment lengths to account for all request lengths being 1 during decoding
                 adapter_segments, _ = find_segments(
                     batch.adapter_meta.adapter_indices
@@ -2194,30 +2218,33 @@
         prefill_logprobs = batch.prefill_next_token_indices is not None
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
-        if batch.speculative_ids is not None:
-            B, speculative_length = batch.speculative_ids.shape
-            new_length = speculative_length + 1
-            adapter_indices = (
-                adapter_meta.adapter_indices.unsqueeze(-1)
-                .expand(B, new_length)
-                .reshape(-1)
-            )
-            adapter_segments = adapter_meta.adapter_segments * new_length
-            adapter_meta = AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_meta.adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_meta.segment_indices,
-            )
+        if adapter_meta is not None:
+            if batch.speculative_ids is not None:
+                B, speculative_length = batch.speculative_ids.shape
+                new_length = speculative_length + 1
+                adapter_indices = (
+                    adapter_meta.adapter_indices.unsqueeze(-1)
+                    .expand(B, new_length)
+                    .reshape(-1)
+                )
+                adapter_segments = adapter_meta.adapter_segments * new_length
+                adapter_meta = AdapterBatchMetadata(
+                    adapter_indices=adapter_indices,
+                    adapter_set=adapter_meta.adapter_set,
+                    adapter_segments=adapter_segments,
+                    segment_indices=adapter_meta.segment_indices,
+                )
 
-        # Assign pointers to adapter weights
-        # TODO(travis): don't update this if indices haven't changed
-        adapter_data = AdapterBatchData.from_meta(
-            adapter_meta,
-            self.layer_to_adapter_weights,
-            prefill,
-            batch.prefill_head_indices,
-        )
+            # Assign pointers to adapter weights
+            # TODO(travis): don't update this if indices haven't changed
+            adapter_data = AdapterBatchData.from_meta(
+                adapter_meta,
+                self.layer_to_adapter_weights,
+                prefill,
+                batch.prefill_head_indices,
+            )
+        else:
+            adapter_data = None
 
         out, speculative_logits = self.forward(batch, adapter_data)
 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index fd239b3e..e604fd3c 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch.prefilling, seqlen, batch_size
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
 
@@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index db3904a2..771cc0a8 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
             input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
         else:
             input_ids = batch.input_ids[0]
-        batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+        batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
 
         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
 
@@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
         orig_bs = len(batch)
@@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,