Fix incorrect output in Qwen2 and Idefics when HPU graphs are used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-21 01:01:37 -07:00
parent 36b6612f97
commit fdf0733f56
7 changed files with 58 additions and 10 deletions


@@ -95,7 +95,10 @@ class PositionRotaryEmbedding(nn.Module):
mrope_section = rope_scaling["mrope_section"]
if mrope_section is not None:
return RotaryPositionEmbeddingMultimodalSections(
inv_freq, scaling_factor, mrope_section
inv_freq,
scaling_factor,
mrope_section,
config.max_position_embeddings,
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
@@ -557,8 +560,13 @@ def apply_llama3_scaling(
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
super().__init__(inv_freq, scaling_factor)
def __init__(
self,
inv_freq: torch.Tensor,
scaling_factor: float,
sections: list,
max_position_embeddings,
):
self.sections = sections
self._cos_cached = None
self._sin_cached = None
@@ -568,6 +576,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
.view(1, 1, -1)
.to(inv_freq.device)
)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int

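For context on the reordering above: super().__init__ now receives max_position_embeddings and, presumably, pre-builds the cos/sin cache, which goes through the subclass's overridden _update_cos_sin_cache and therefore needs self.sections to exist first. A minimal sketch of that ordering constraint (class names and the cache logic are illustrative, not the actual TGI implementation):

import torch


class RotaryBase:
    # Stand-in for PositionRotaryEmbedding; assumed to build the cache in __init__.
    def __init__(self, inv_freq: torch.Tensor, max_position_embeddings: int):
        self.inv_freq = inv_freq
        self._update_cos_sin_cache(max_position_embeddings)

    def _update_cos_sin_cache(self, seqlen: int):
        t = torch.arange(seqlen, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        self._cos_cached, self._sin_cached = freqs.cos(), freqs.sin()


class MultimodalRotary(RotaryBase):
    def __init__(self, inv_freq, sections, max_position_embeddings):
        # Attributes read by the overridden cache builder must be set before
        # super().__init__ runs, hence the move in the diff above.
        self.sections = sections
        super().__init__(inv_freq, max_position_embeddings)

    def _update_cos_sin_cache(self, seqlen: int):
        super()._update_cos_sin_cache(seqlen)
        # A real implementation would additionally split the cache by self.sections.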

@@ -110,6 +110,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
)
if lm_head_indices is not None:


@@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# replace the boolean `==` mask with torch.where indices to fix incorrect output under HPU graphs
mask = torch.where(input_ids == self.config.image_token_id)
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
@@ -794,6 +795,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
@@ -801,6 +803,21 @@
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
# HPU does not support unfold; emulate the per-patch sums with conv2d
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(

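A conv2d with an all-ones patch_size x patch_size kernel and stride patch_size produces the same per-patch sums that the two unfold calls computed, since each output element is just the sum of one non-overlapping patch. A toy check (mask values and patch_size are made up):

import torch
import torch.nn.functional as F

patch_size = 2
mask = torch.tensor([[[1, 1, 1, 0],
                      [1, 1, 1, 0],
                      [1, 1, 0, 0],
                      [1, 1, 0, 0]]], dtype=torch.float32)

# unfold path (the code now kept under the triple-quoted string):
sums_unfold = (
    mask.unfold(1, patch_size, patch_size)
    .unfold(2, patch_size, patch_size)
    .sum(dim=(-1, -2))
)

# conv2d path used on HPU: same sums via an all-ones kernel.
kernel = torch.ones(1, 1, patch_size, patch_size)
sums_conv = F.conv2d(mask.unsqueeze(1), kernel, stride=patch_size).squeeze(1)

assert torch.equal(sums_unfold, sums_conv)

Note that the thresholds differ: the original kept patches whose sum is > 0 (any attended pixel), while the replacement keeps patches whose sum equals patch_size * patch_size (all pixels attended); the two coincide only on patches that are entirely attended or entirely masked.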

@@ -471,7 +471,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# replace the boolean `==` mask with torch.where indices to fix incorrect output under HPU graphs
mask = torch.where(input_ids == self.config.image_token_id)
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
@@ -539,6 +540,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
@@ -546,6 +548,21 @@ class Idefics3ForConditionalGeneration(nn.Module):
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
# HPU does not support unfold; emulate the per-patch sums with conv2d
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(


@@ -739,10 +739,12 @@ class Qwen2_5VisionModel(nn.Module):
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
device="cpu",
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
hidden_states.device
)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
@@ -928,7 +930,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,

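torch.unique_consecutive has a data-dependent output length, which is awkward inside a captured HPU graph, so the change above builds cu_window_seqlens on the CPU, de-duplicates it there, and only then moves the fixed result to the device. A small sketch of the pattern with made-up window boundaries (the device string is illustrative; on Gaudi it would be hidden_states.device):

import torch

cu_window_seqlens = [0, 0, 4, 4, 8, 8, 8, 12]

# Build and de-duplicate on CPU, where dynamic output shapes are cheap.
boundaries = torch.tensor(cu_window_seqlens, device="cpu", dtype=torch.int32)
boundaries = torch.unique_consecutive(boundaries)

# Transfer the final tensor to the accelerator in a single copy.
device = "cpu"  # stand-in for hidden_states.device
boundaries = boundaries.to(device)
print(boundaries)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)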

@@ -503,7 +503,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,


@@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
return FlashGemmaForCausalLM(prefix, config, weights)
elif config.model_type == "gemma2":
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,