remove torch.where to fix incorrect output when running the model with HPU graphs

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-31 22:51:54 -07:00
parent f0e5faec1a
commit c55a8caea2
4 changed files with 5 additions and 8 deletions
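
Background on the change: torch.where-based filtering produces tensors whose shape depends on how many valid slots the batch happens to contain, while HPU graphs are captured once with fixed tensor shapes, so replaying the captured graph on data with a different number of valid slots can produce incorrect output. The sketch below is illustrative only and not the project's code (BLOCK_SIZE and the function names are assumptions); it contrasts the data-dependent-shape pattern that this commit removes with the static-shape pattern that replaces it.

import torch

BLOCK_SIZE = 16  # assumed block size, for illustration only


def slot_coords_dynamic(slots: torch.Tensor):
    # Old pattern: drop padding slots (-1) with torch.where. The result's shape
    # depends on the data, which an HPU graph captured with fixed shapes
    # cannot replay reliably.
    mask = torch.where(slots != -1)
    kept = slots[mask]
    return kept // BLOCK_SIZE, kept % BLOCK_SIZE


def slot_coords_static(slots: torch.Tensor):
    # New pattern: operate on the full, fixed-shape slots tensor. Padding
    # positions hold a valid slot index (0), so no masking is needed and
    # every intermediate tensor keeps a static shape.
    return slots // BLOCK_SIZE, slots % BLOCK_SIZE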

@@ -115,13 +115,10 @@ def paged_reshape_and_cache(
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
-    mask = torch.where(slots != -1)
-    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:

@@ -1661,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(

@@ -458,7 +458,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             cu_seqlen_q=cu_seqlen_prefill,
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(

@@ -283,7 +283,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
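
The slots_pad change, repeated in FlashCausalLM, FlashVlmCausalLM, and FlashMllamaCausalLM above, follows the same idea: padding positions now hold slot 0, a valid index, instead of the -1 sentinel that previously had to be filtered out with torch.where. A minimal, self-contained sketch of that padding step with made-up sizes (only the slots_pad lines mirror the diff; the tensor names and values are illustrative):

import torch

# Illustrative padded batch: 8 token positions, the first 5 of which are real.
input_ids = torch.zeros(8, dtype=torch.long)
prefill_cache_indices = torch.tensor([0, 1, 2, 3, 4])
slots = torch.tensor([10, 11, 12, 13, 14])

# Padding positions default to slot 0 (a valid index) rather than -1, so the
# downstream cache insert can run unmasked with static shapes.
slots_pad = torch.zeros_like(input_ids)
slots_pad[prefill_cache_indices] = slots
print(slots_pad)  # tensor([10, 11, 12, 13, 14,  0,  0,  0])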