mirror of https://github.com/huggingface/text-generation-inference.git
Fix some issues
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 376e0507b7
commit f0e5faec1a
@@ -5,6 +5,7 @@ import torch
 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
+from vllm_hpu_extension import cache_ops


 @dataclass
@@ -115,12 +116,12 @@ def paged_reshape_and_cache(
     v_scale: float = 1.0,
 ):
-    from vllm_hpu_extension import cache_ops
-
+    mask = torch.where(slots != -1)
+    slots = slots[mask]
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset)
+    cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
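The change above stops padded positions from being written into the KV cache: they now carry the sentinel slot -1 and are filtered out before the block-index math. A minimal CPU-only sketch of the same masking, where toy_insert is a hypothetical stand-in for cache_ops.insert_or_update_cache (which needs Gaudi hardware) and BLOCK_SIZE = 4 is an arbitrary example value:

# Sketch of the masked cache update, plain torch on CPU.
import torch

BLOCK_SIZE = 4

def toy_insert(src, cache, block_idx, block_offset):
    # Write each row of `src` into its (block, offset) slot of `cache`.
    cache[block_idx, block_offset] = src

key = torch.arange(5, dtype=torch.float32).unsqueeze(-1)  # 5 tokens, head_dim=1
key_cache = torch.zeros(3, BLOCK_SIZE, 1)                 # 3 blocks of 4 slots

# Padded positions carry the sentinel slot -1 and must not touch the cache.
slots = torch.tensor([0, 1, -1, 5, -1])

mask = torch.where(slots != -1)          # index tensor of valid positions
valid_slots = slots[mask]                # drop the -1 sentinels
block_idx = valid_slots // BLOCK_SIZE    # which block each slot lives in
block_offset = valid_slots % BLOCK_SIZE  # position inside that block

toy_insert(key[mask], key_cache, block_idx, block_offset)
print(key_cache.squeeze(-1))
# Only slots 0, 1 and 5 are written; the -1 entries are skipped.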
@@ -153,7 +153,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         image_features: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
+        mask = torch.where(input_ids == self.config.image_token_index)
         # Let's pray we have enabled enough slots !
         try:
             inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
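For the Llava merge, the boolean mask is swapped for torch.where index tensors. Both forms address the same positions in the in-place assignment, as this small sketch (with made-up shapes and an example token id) checks:

# Boolean mask vs. torch.where index tensors: same selection, different form.
import torch

image_token_index = 32000                      # example token id
input_ids = torch.tensor([[1, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 4, 8)           # (batch, seq, hidden)
image_features = torch.ones(2, 8)              # one row per image token

bool_mask = input_ids == image_token_index               # boolean form (old)
idx_mask = torch.where(input_ids == image_token_index)   # index form (new)

a = inputs_embeds.clone()
b = inputs_embeds.clone()
a[bool_mask] = image_features.view(-1, image_features.shape[-1])
b[idx_mask] = image_features.view(-1, image_features.shape[-1])
assert torch.equal(a, b)  # identical result either way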
@@ -998,7 +998,8 @@ class FlashCausalLMBatch(Batch):
                 input_ids = [0] * extra_pad + input_ids
                 self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
             else:
+                logger.error("should not be here, prefill self.input_ids is a tensor")
                 self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
             input_ids_padded_length.append(extra_pad)

         self.input_lengths_tensor = torch.tensor(
             self.input_lengths, dtype=torch.int32, device=device
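In the tensor branch, F.pad with a (extra_pad, 0) pad spec prepends zeros to the last dimension, mirroring the `[0] * extra_pad + input_ids` list branch; a quick sketch:

# Left-padding a 1-D tensor of token ids with zeros.
import torch
import torch.nn.functional as F

input_ids = torch.tensor([11, 12, 13], dtype=torch.int64)
extra_pad = 2
padded = F.pad(input_ids, (extra_pad, 0), value=0)
print(padded)  # tensor([ 0,  0, 11, 12, 13])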
@@ -1660,7 +1661,7 @@ class FlashCausalLM(Model):
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         logits, speculative_logits = self.model.forward(
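This is the change the masking in paged_reshape_and_cache relies on: with torch.zeros_like, every padded position looked like a legitimate write to real cache slot 0, whereas -1 is out-of-band and gets dropped by the slots != -1 mask above. A toy illustration with assumed shapes:

# Sentinel-padded slots: real positions keep their slots, padding becomes -1.
import torch

input_ids = torch.zeros(6, dtype=torch.int64)       # padded flat batch of 6
prefill_cache_indices = torch.tensor([0, 1, 2, 4])  # positions of real tokens
slots = torch.tensor([7, 8, 9, 10])                 # their real cache slots

slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
slots_pad[prefill_cache_indices] = slots
print(slots_pad)  # tensor([ 7,  8,  9, -1, 10, -1])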
@@ -457,6 +457,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -282,7 +282,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = False
-
+        if batch.prefill_cache_indices is not None:
+            slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1
+            slots_pad[batch.prefill_cache_indices] = slots
+            slots = slots_pad
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
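The same sentinel-padding guard is duplicated here for FlashMllamaCausalLM, matching the FlashVlmCausalLM hunk above, so every prefill path that reaches paged_reshape_and_cache appears to hand it -1-marked padding rather than stray zeros.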