mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
remove torch.where to fix incorrect output in hpu graph model
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
f0e5faec1a
commit
c55a8caea2
@ -115,13 +115,10 @@ def paged_reshape_and_cache(
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
):
|
||||
|
||||
mask = torch.where(slots != -1)
|
||||
slots = slots[mask]
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||
|
||||
|
||||
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
|
||||
|
@ -1661,7 +1661,7 @@ class FlashCausalLM(Model):
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = False
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
logits, speculative_logits = self.model.forward(
|
||||
|
@ -458,7 +458,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
logits, speculative_logits = self.model.forward(
|
||||
|
@ -283,7 +283,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = False
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
logits, speculative_logits = self.model.forward(
|
||||
|
Loading…
Reference in New Issue
Block a user