From f0e5faec1a77b3d6719d8363113db3361a314b32 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Fri, 28 Mar 2025 07:01:06 -0700
Subject: [PATCH] fix prefill slot padding and input handling in Gaudi backend

Pad prefill slots with -1 instead of 0 and skip those positions when
updating the KV cache, apply the same slot padding in the VLM and Mllama
forward paths, build the llava-next image-token mask with torch.where,
and left-pad tensor input_ids during prefill instead of logging an error.

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/layers/attention/kv_cache.py  | 9 +++++----
 .../models/custom_modeling/flash_llava_next.py            | 2 +-
 .../text_generation_server/models/flash_causal_lm.py      | 5 +++--
 .../text_generation_server/models/flash_vlm_causal_lm.py  | 4 ++++
 .../text_generation_server/models/mllama_causal_lm.py     | 5 ++++-
 5 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
index 26c80c70..cdd0458b 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -5,6 +5,7 @@ import torch
 
 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
+from vllm_hpu_extension import cache_ops
 
 
 @dataclass
@@ -115,12 +116,12 @@ def paged_reshape_and_cache(
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
-    from vllm_hpu_extension import cache_ops
-
+    mask = torch.where(slots != -1)
+    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)
 
 
 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
index 62e8470c..88548042 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
@@ -153,7 +153,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         image_features: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
+        mask = torch.where(input_ids == self.config.image_token_index)
         # Let's pray we have enabled enough slots !
         try:
             inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index b0859c3d..816f05d0 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -998,7 +998,8 @@ class FlashCausalLMBatch(Batch):
                 input_ids = [0] * extra_pad + input_ids
                 self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
             else:
-                logger.error("should not be here, prefill self.input_ids is a tensor")
+                self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
+                input_ids_padded_length.append(extra_pad)
 
         self.input_lengths_tensor = torch.tensor(
             self.input_lengths, dtype=torch.int32, device=device
@@ -1660,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index b5d93cbc..f630a85a 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -457,6 +457,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index eabbe247..bd123725 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -282,7 +282,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
-
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
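
Note: the sketch below is an illustration of the padded-slot handling this patch introduces, not part of the diff. It reproduces the forward() padding and the paged_reshape_and_cache masking with plain PyTorch indexing as a stand-in for cache_ops.insert_or_update_cache; the helper names (pad_slots, paged_reshape_and_cache_sketch), the BLOCK_SIZE value, and all tensor shapes and slot ids are made-up example values.

import torch

BLOCK_SIZE = 4  # example block size, not the real Gaudi value


def pad_slots(input_ids, prefill_cache_indices, slots):
    # Mirrors the forward() change: positions added by bucketing/graph padding
    # get slot id -1 instead of 0, so they can never alias real cache slot 0.
    slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
    slots_pad[prefill_cache_indices] = slots
    return slots_pad


def paged_reshape_and_cache_sketch(key, value, key_cache, value_cache, slots):
    # Mirrors the kv_cache.py change: drop padded (-1) slots before writing.
    mask = torch.where(slots != -1)
    slots = slots[mask]
    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    # Plain indexing as a stand-in for cache_ops.insert_or_update_cache.
    key_cache[block_idx, block_offset] = key[mask]
    value_cache[block_idx, block_offset] = value[mask]


# 8 padded token positions, of which only the first 5 are real.
input_ids = torch.zeros(8, dtype=torch.int64)
prefill_cache_indices = torch.arange(5)
slots = pad_slots(input_ids, prefill_cache_indices, torch.tensor([12, 13, 14, 15, 16]))
# slots -> tensor([12, 13, 14, 15, 16, -1, -1, -1])

num_blocks, heads, head_dim = 8, 2, 4
key = torch.randn(8, heads, head_dim)
value = torch.randn(8, heads, head_dim)
key_cache = torch.zeros(num_blocks, BLOCK_SIZE, heads, head_dim)
value_cache = torch.zeros_like(key_cache)
paged_reshape_and_cache_sketch(key, value, key_cache, value_cache, slots)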
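
Note: the snippet below illustrates the flash_causal_lm.py prefill change in isolation, not part of the diff: when self.input_ids is already a tensor it is now left-padded with torch.nn.functional.pad (pad value 0) and the pad length is recorded, instead of hitting the old error branch. The token ids and extra_pad value are made-up examples.

import torch
import torch.nn.functional as F

input_ids = torch.tensor([101, 7592, 2088], dtype=torch.int64)  # example token ids
extra_pad = 2  # example bucket padding
padded = F.pad(input_ids, (extra_pad, 0), value=0)  # left-pad with token id 0
# padded -> tensor([  0,   0,  101, 7592, 2088])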