diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
index cdd0458b..d238cdb9 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -115,13 +115,10 @@ def paged_reshape_and_cache(
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
-
-    mask = torch.where(slots != -1)
-    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
 
 
 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 816f05d0..a4d58596 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1661,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index f630a85a..208ab358 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -458,7 +458,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             cu_seqlen_q=cu_seqlen_prefill,
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index bd123725..e034ed49 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -283,7 +283,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
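
For context, a minimal standalone sketch of the slot handling after this change: since paged_reshape_and_cache no longer filters out -1 padding slots, the callers pad with torch.zeros_like(input_ids), so every position (including padding) maps to some (block, offset) pair. The BLOCK_SIZE value, tensor shapes, and slot values below are illustrative assumptions, not values taken from the repository.

import torch

BLOCK_SIZE = 128  # assumption for illustration; the real value comes from the backend

def slot_to_block(slots: torch.Tensor):
    # Same arithmetic as paged_reshape_and_cache after this diff: no masking,
    # every slot (padded ones included) resolves to a block index and offset.
    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    return block_idx, block_offset

# Pad slots the way the models now do during prefill: zeros everywhere,
# with the real slot values scattered into the cached positions.
input_ids = torch.arange(6)                      # stand-in prefill tokens
prefill_cache_indices = torch.tensor([0, 2, 4])  # hypothetical cached positions
real_slots = torch.tensor([130, 131, 132])       # hypothetical slot ids

slots_pad = torch.zeros_like(input_ids)
slots_pad[prefill_cache_indices] = real_slots

print(slot_to_block(slots_pad))
# (tensor([1, 0, 1, 0, 1, 0]), tensor([2, 0, 3, 0, 4, 0]))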