fix HPU prefill padding: mark padded slots with -1, skip them in the KV cache insert, and pad tensor input_ids instead of logging an error

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-28 07:01:06 -07:00
parent 376e0507b7
commit f0e5faec1a
5 changed files with 17 additions and 8 deletions


@@ -5,6 +5,7 @@ import torch
 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
+from vllm_hpu_extension import cache_ops

 @dataclass
@@ -115,12 +116,12 @@ def paged_reshape_and_cache(
     v_scale: float = 1.0,
 ):
-    from vllm_hpu_extension import cache_ops
+    mask = torch.where(slots != -1)
+    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)

 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
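
For context: on this HPU path, prefill batches are padded to fixed bucket sizes, and the padded positions now carry a slot value of -1 (see the slots_pad changes further down). The mask added above keeps those positions out of the KV cache. A minimal sketch of that logic, not taken from the diff, assuming BLOCK_SIZE = 4 and made-up tensor shapes:

import torch

BLOCK_SIZE = 4  # assumed for illustration; the real value comes from models.globals

# Six prefill positions; the last two are padding and carry the -1 sentinel slot.
slots = torch.tensor([0, 1, 2, 3, -1, -1])
key = torch.randn(6, 8, 64)    # [tokens, kv_heads, head_dim], made-up shapes
value = torch.randn(6, 8, 64)

mask = torch.where(slots != -1)          # tuple of index tensors for real tokens
slots = slots[mask]                      # tensor([0, 1, 2, 3])
block_idx = slots // BLOCK_SIZE          # tensor([0, 0, 0, 0])
block_offset = slots % BLOCK_SIZE        # tensor([0, 1, 2, 3])

# Only key[mask] / value[mask] (4 rows instead of 6) are handed to the cache insert.
# Without the mask, slot -1 would map to block_idx -1 and block_offset 3, so the
# padding rows would silently overwrite a real cache line.
assert key[mask].shape[0] == block_idx.shape[0]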


@@ -153,7 +153,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         image_features: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
+        mask = torch.where(input_ids == self.config.image_token_index)
         # Let's pray we have enabled enough slots !
         try:
             inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
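
For context: torch.where(cond) returns a tuple of index tensors, so the new mask scatters image_features into inputs_embeds exactly like the old boolean mask, just with explicit indices (avoiding boolean advanced indexing, which is presumably what the HPU path prefers; that motivation is an assumption, not stated in the commit). A toy equivalence check, not taken from the diff, with a hypothetical image_token_index:

import torch

image_token_index = 32000  # hypothetical value for illustration
input_ids = torch.tensor([1, 32000, 32000, 5, 32000])
image_features = torch.arange(12, dtype=torch.float32).reshape(3, 4)

old = torch.zeros(5, 4)
new = torch.zeros(5, 4)

bool_mask = input_ids == image_token_index                 # old form: boolean mask
index_mask = torch.where(input_ids == image_token_index)   # new form: index tuple

old[bool_mask] = image_features.view(-1, image_features.shape[-1])
new[index_mask] = image_features.view(-1, image_features.shape[-1])
assert torch.equal(old, new)  # same rows receive the same image features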


@@ -998,7 +998,8 @@ class FlashCausalLMBatch(Batch):
             input_ids = [0] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            logger.error("should not be here, prefill self.input_ids is a tensor")
+            self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
+            input_ids_padded_length.append(extra_pad)
         self.input_lengths_tensor = torch.tensor(
             self.input_lengths, dtype=torch.int32, device=device
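
For context: when self.input_ids already arrives as a tensor, the old branch only logged an error; the new branch left-pads the tensor itself so the prompt still fills the padded bucket. A minimal sketch of the F.pad call, not taken from the diff, on a made-up tensor:

import torch
import torch.nn.functional as F

input_ids = torch.tensor([11, 12, 13, 14, 15], dtype=torch.int64)  # made-up prompt ids
extra_pad = 3

# (left, right) padding on the last dimension; value=0 matches the list branch above,
# which prepends [0] * extra_pad.
padded = F.pad(input_ids, (extra_pad, 0), value=0)
print(padded)  # tensor([ 0,  0,  0, 11, 12, 13, 14, 15])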
@@ -1660,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
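
For context: this is the producer side of the kv_cache change above. torch.zeros_like made every padded position claim slot 0, which is a real cache slot; the -1 sentinel lets paged_reshape_and_cache drop those positions instead. A small sketch, not taken from the diff, with made-up indices:

import torch

input_ids = torch.zeros(8, dtype=torch.int64)            # 8 bucketed positions
prefill_cache_indices = torch.tensor([0, 1, 2, 3, 4])    # positions holding real tokens
slots = torch.tensor([40, 41, 42, 43, 44])               # their cache slots

slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
slots_pad[prefill_cache_indices] = slots
print(slots_pad)  # tensor([40, 41, 42, 43, 44, -1, -1, -1])
# With torch.zeros_like the last three entries would have been 0, and the padding
# keys/values would have been written into slot 0 of the cache.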


@@ -457,6 +457,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,


@@ -282,7 +282,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,