mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 03:52:08 +00:00
fix incorrect output in qwen2 idefics if hpu graph is used
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 36b6612f97
commit fdf0733f56
@@ -95,7 +95,10 @@ class PositionRotaryEmbedding(nn.Module):
                 mrope_section = rope_scaling["mrope_section"]
                 if mrope_section is not None:
                     return RotaryPositionEmbeddingMultimodalSections(
-                        inv_freq, scaling_factor, mrope_section
+                        inv_freq,
+                        scaling_factor,
+                        mrope_section,
+                        config.max_position_embeddings,
                     )
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
@@ -557,8 +560,13 @@ def apply_llama3_scaling(


 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
-    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
-        super().__init__(inv_freq, scaling_factor)
+    def __init__(
+        self,
+        inv_freq: torch.Tensor,
+        scaling_factor: float,
+        sections: list,
+        max_position_embeddings,
+    ):
         self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
@@ -568,6 +576,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
             .view(1, 1, -1)
             .to(inv_freq.device)
         )
+        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

     def _update_cos_sin_cache(
         self, dtype: torch.dtype, device: torch.device, seqlen: int
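The two hunks above rework RotaryPositionEmbeddingMultimodalSections so that config.max_position_embeddings reaches the base class, and super().__init__ now runs only after the section tensor is built. Presumably this lets the base class size its cos/sin caches once, up front, rather than growing them lazily inside a captured HPU graph. A minimal sketch of that precompute pattern, with illustrative names rather than the TGI class:

import torch

class RotaryCacheSketch:
    # Toy stand-in for the cache handling in PositionRotaryEmbedding.
    def __init__(self, inv_freq: torch.Tensor, max_position_embeddings: int):
        # Precompute cos/sin for every position at construction time, so a
        # graph-captured forward pass only ever indexes into fixed tensors.
        t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)  # (max_pos, dim // 2)
        self._cos_cached = freqs.cos()
        self._sin_cached = freqs.sin()

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
cache = RotaryCacheSketch(inv_freq, max_position_embeddings=4096)
print(cache._cos_cached.shape)  # torch.Size([4096, 32])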
@@ -110,6 +110,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             slots=slots,
             seqlen=seqlen,
             hpu_attention_meta=hpu_attention_meta,
+            prefill_cache_indices=None,
         )

         if lm_head_indices is not None:
@@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
         # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
+        # - replace `==` with torch.where to fix the issue in hpu graph
+        mask = torch.where(input_ids == self.config.image_token_id)
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
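This is the core of the fix, repeated below for Idefics3 and both Qwen2-VL models: boolean-mask indexing is replaced by indexing with the tuple of index tensors that torch.where returns. On eager PyTorch the two forms write identical values; the where form just expresses the scatter through integer indices, which is what the HPU graph path needs. A small self-contained check with illustrative tensors, not the model code:

import torch

IMAGE_TOKEN_ID = 32000  # hypothetical id for this sketch

input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.randn(2, 8)

# Boolean-mask assignment (the old form).
a = inputs_embeds.clone()
a[input_ids == IMAGE_TOKEN_ID] = image_features

# torch.where with a single argument returns a tuple of index tensors,
# which drives the same advanced-indexing assignment (the new form).
b = inputs_embeds.clone()
b[torch.where(input_ids == IMAGE_TOKEN_ID)] = image_features

assert torch.equal(a, b)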
@@ -794,6 +795,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             ].contiguous()

             patch_size = self.config.vision_config.patch_size
+            """
             patches_subgrid = pixel_attention_mask.unfold(
                 dimension=1, size=patch_size, step=patch_size
             )
@@ -801,6 +803,21 @@ class Idefics2ForConditionalGeneration(nn.Module):
                 dimension=2, size=patch_size, step=patch_size
             )
             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+            """
+            # hpu does not support unfold
+            conv_kernel = torch.ones(
+                [1, 1, patch_size, patch_size],
+                dtype=pixel_values.dtype,
+                device=pixel_values.device,
+            )
+            patches_subgrid = torch.nn.functional.conv2d(
+                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+                conv_kernel,
+                stride=patch_size,
+            ).squeeze(1)
+            patch_attention_mask = torch.eq(
+                patches_subgrid, (patch_size * patch_size)
+            )

             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(
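The unfold-based patch mask is commented out because, per the added comment, unfold is not supported on HPU, and the per-patch sums are recomputed as a conv2d with an all-ones kernel. With stride equal to the kernel size, that convolution produces exactly the per-window pixel counts that unfold plus a window sum would; note the commit also tightens the threshold from any valid pixel (> 0) to fully valid patches (== patch_size * patch_size). A quick equivalence check of the counts, under those assumptions:

import torch
import torch.nn.functional as F

patch_size = 4  # illustrative; real vision configs use values like 14
mask = torch.randint(0, 2, (2, 16, 16), dtype=torch.bool)

# unfold-based per-patch pixel counts (the commented-out original).
counts_unfold = (
    mask.unfold(1, patch_size, patch_size)
    .unfold(2, patch_size, patch_size)
    .sum(dim=(-1, -2))
)

# conv2d with a ones kernel and stride == patch_size counts the same pixels.
kernel = torch.ones(1, 1, patch_size, patch_size)
counts_conv = F.conv2d(
    mask.unsqueeze(1).to(kernel.dtype), kernel, stride=patch_size
).squeeze(1)

assert torch.equal(counts_unfold.to(kernel.dtype), counts_conv)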
@@ -471,7 +471,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
         # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
+        # - replace `==` with torch.where to fix the issue in hpu graph
+        mask = torch.where(input_ids == self.config.image_token_id)
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
@@ -539,6 +540,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
             ].contiguous()

             patch_size = self.config.vision_config.patch_size
+            """
             patches_subgrid = pixel_attention_mask.unfold(
                 dimension=1, size=patch_size, step=patch_size
             )
@@ -546,6 +548,21 @@ class Idefics3ForConditionalGeneration(nn.Module):
                 dimension=2, size=patch_size, step=patch_size
             )
             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+            """
+            # hpu does not support unfold
+            conv_kernel = torch.ones(
+                [1, 1, patch_size, patch_size],
+                dtype=pixel_values.dtype,
+                device=pixel_values.device,
+            )
+            patches_subgrid = torch.nn.functional.conv2d(
+                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+                conv_kernel,
+                stride=patch_size,
+            ).squeeze(1)
+            patch_attention_mask = torch.eq(
+                patches_subgrid, (patch_size * patch_size)
+            )

             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(
@@ -739,10 +739,12 @@ class Qwen2_5VisionModel(nn.Module):

         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
-            device=hidden_states.device,
+            device="cpu",
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
-        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
+            hidden_states.device
+        )

         # create a cu_seqlens tensor to be used in the attention mask
         cu_seqlens = torch.repeat_interleave(
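Here the window offsets are assembled on the CPU and only the final, deduplicated tensor is moved to the device, keeping the Python-list-to-tensor conversion off the device-side graph. A minimal sketch of the pattern:

import torch

device = torch.device("cpu")  # stand-in for the HPU device in this sketch
cu_window_seqlens = [0, 16, 16, 32, 48, 48, 64]  # hypothetical offsets

# Build on CPU, collapse consecutive duplicates, then transfer once.
t = torch.tensor(cu_window_seqlens, device="cpu", dtype=torch.int32)
t = torch.unique_consecutive(t).to(device)
print(t)  # tensor([ 0, 16, 32, 48, 64], dtype=torch.int32)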
@@ -928,7 +930,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
             image_embeds = self.visual(
                 pixel_values, grid_thw=image_grid_thw
             ).squeeze(0)
-            inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            mask = torch.where(input_ids == self.image_token_id)
+            inputs_embeds[mask] = image_embeds

         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
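Both Qwen2-VL variants receive the same where-based merge as the Idefics models. The merge silently relies on the number of image-token positions matching the number of vision embeddings; a hedged sketch of a guard one could add (purely illustrative, not part of the commit):

import torch

image_token_id = 151655  # illustrative id for this sketch

input_ids = torch.tensor([[7, image_token_id, image_token_id, 9]])
image_embeds = torch.randn(2, 16)
inputs_embeds = torch.zeros(1, 4, 16)

mask = torch.where(input_ids == image_token_id)
# Guard (not in the commit): one embedding row per image-token slot.
assert mask[0].numel() == image_embeds.shape[0], "image slots != image embeddings"
inputs_embeds[mask] = image_embeds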
@@ -503,7 +503,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_embeds = self.visual(
                 pixel_values, grid_thw=image_grid_thw
             ).squeeze(0)
-            inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            mask = torch.where(input_ids == self.image_token_id)
+            inputs_embeds[mask] = image_embeds

         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
@@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )

-        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
+        return FlashGemmaForCausalLM(prefix, config, weights)
     elif config.model_type == "gemma2":
         from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
             FlashGemma2ForCausalLM,
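Dropping causal=False means the PaliGemma text model now runs with FlashGemmaForCausalLM's default. In decoder-only models that flag typically selects lower-triangular masking instead of bidirectional attention; a toy illustration of the distinction, assumed semantics rather than the FlashGemma internals:

import torch

seqlen = 4
# causal: each position attends only to itself and earlier positions.
causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
# non-causal (causal=False): every position attends to every position.
full_mask = torch.ones(seqlen, seqlen, dtype=torch.bool)
print(causal_mask.int())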