Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
remove torch.where to fix incorrect output in hpu graph model
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent f0e5faec1a
commit c55a8caea2
@@ -115,13 +115,10 @@ def paged_reshape_and_cache(
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
-
-    mask = torch.where(slots != -1)
-    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
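Note (not part of the diff): a minimal standalone sketch of why the masked path is awkward under HPU graphs. torch.where produces index tensors whose length depends on how many slots are real in a given batch, so a captured graph cannot be replayed with one fixed shape; writing every position with a static shape avoids that. The example values below are made up, and treating slot 0 as a safe padding target is an assumption, not something stated in the diff.

import torch

# Standalone illustration, not TGI code.
slots = torch.tensor([5, 9, -1, -1])   # -1 marked padded positions (old scheme)
key = torch.randn(4, 8)

# Old path: the result shape depends on how many slots are real -> dynamic shape.
mask = torch.where(slots != -1)
print(key[mask].shape)                 # torch.Size([2, 8]) for this batch only

# New path: every position is written and padding points at slot 0 (assumed to
# be a scratch slot), so shapes stay fixed and the HPU graph can be replayed.
slots_fixed = torch.zeros_like(slots)
slots_fixed[:2] = slots[:2]            # keep the real slots, leave 0 for padding
print(key.shape, slots_fixed.shape)    # torch.Size([4, 8]) torch.Size([4])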
@@ -1661,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
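Note (not part of the diff): a small sketch of the padding change above. Previously padded positions carried a -1 sentinel that paged_reshape_and_cache filtered out with torch.where; now the pad tensor starts at zero, so no data-dependent mask is needed downstream. The tensor sizes and prefill_cache_indices values below are hypothetical example inputs.

import torch

# Hypothetical example values; in TGI these come from the prefill batch.
input_ids = torch.zeros(8, dtype=torch.long)            # padded prefill tokens
prefill_cache_indices = torch.tensor([0, 1, 2, 5, 6])   # positions of real tokens
slots = torch.tensor([32, 33, 34, 64, 65])              # one slot per real token

# Old: sentinel -1 for padding, removed later with torch.where in the cache op.
slots_pad_old = torch.ones_like(input_ids, dtype=torch.long) * -1
slots_pad_old[prefill_cache_indices] = slots

# New: padding defaults to slot 0, so shapes stay static for HPU graph capture
# (slot 0 is assumed to be safe to overwrite).
slots_pad_new = torch.zeros_like(input_ids)
slots_pad_new[prefill_cache_indices] = slots

print(slots_pad_old.tolist())   # [32, 33, 34, -1, -1, 64, 65, -1]
print(slots_pad_new.tolist())   # [32, 33, 34, 0, 0, 64, 65, 0]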
@@ -458,7 +458,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             cu_seqlen_q=cu_seqlen_prefill,
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
@@ -283,7 +283,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(