diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py
index b25f9fab..6a83d6a5 100644
--- a/backends/gaudi/server/text_generation_server/layers/rotary.py
+++ b/backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -95,7 +95,10 @@ class PositionRotaryEmbedding(nn.Module):
                 mrope_section = rope_scaling["mrope_section"]
                 if mrope_section is not None:
                     return RotaryPositionEmbeddingMultimodalSections(
-                        inv_freq, scaling_factor, mrope_section
+                        inv_freq,
+                        scaling_factor,
+                        mrope_section,
+                        config.max_position_embeddings,
                     )
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
@@ -557,8 +560,13 @@ def apply_llama3_scaling(
 
 
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
-    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
-        super().__init__(inv_freq, scaling_factor)
+    def __init__(
+        self,
+        inv_freq: torch.Tensor,
+        scaling_factor: float,
+        sections: list,
+        max_position_embeddings,
+    ):
         self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
@@ -568,6 +576,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
             .view(1, 1, -1)
             .to(inv_freq.device)
         )
+        super().__init__(inv_freq, scaling_factor, max_position_embeddings)
 
     def _update_cos_sin_cache(
         self, dtype: torch.dtype, device: torch.device, seqlen: int
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 532f118f..af0f8f89 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -110,6 +110,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             slots=slots,
             seqlen=seqlen,
             hpu_attention_meta=hpu_attention_meta,
+            prefill_cache_indices=None,
         )
 
         if lm_head_indices is not None:
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
index 31a01d7c..0a4305ec 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
         # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
+        # Replace `==` with torch.where to fix the issue with HPU graphs
+        mask = torch.where(input_ids == self.config.image_token_id)
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
@@ -794,6 +795,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
                 ].contiguous()
 
             patch_size = self.config.vision_config.patch_size
+            """
             patches_subgrid = pixel_attention_mask.unfold(
                 dimension=1, size=patch_size, step=patch_size
             )
@@ -801,6 +803,21 @@ class Idefics2ForConditionalGeneration(nn.Module):
                 dimension=2, size=patch_size, step=patch_size
             )
             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+            """
+            # HPU does not support unfold
+            conv_kernel = torch.ones(
+                [1, 1, patch_size, patch_size],
+                dtype=pixel_values.dtype,
+                device=pixel_values.device,
+            )
+            patches_subgrid = torch.nn.functional.conv2d(
+                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+                conv_kernel,
+                stride=patch_size,
+            ).squeeze(1)
+            patch_attention_mask = torch.eq(
+                patches_subgrid, (patch_size * patch_size)
+            )
 
             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
index ce5e8115..9278a86a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -471,7 +471,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
         # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
+        # Replace `==` with torch.where to fix the issue with HPU graphs
+        mask = torch.where(input_ids == self.config.image_token_id)
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
@@ -539,6 +540,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
                 ].contiguous()
 
             patch_size = self.config.vision_config.patch_size
+            """
             patches_subgrid = pixel_attention_mask.unfold(
                 dimension=1, size=patch_size, step=patch_size
             )
@@ -546,6 +548,21 @@ class Idefics3ForConditionalGeneration(nn.Module):
                 dimension=2, size=patch_size, step=patch_size
            )
             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+            """
+            # HPU does not support unfold
+            conv_kernel = torch.ones(
+                [1, 1, patch_size, patch_size],
+                dtype=pixel_values.dtype,
+                device=pixel_values.device,
+            )
+            patches_subgrid = torch.nn.functional.conv2d(
+                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+                conv_kernel,
+                stride=patch_size,
+            ).squeeze(1)
+            patch_attention_mask = torch.eq(
+                patches_subgrid, (patch_size * patch_size)
+            )
 
             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index 832efdfa..75dd2b40 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -739,10 +739,12 @@ class Qwen2_5VisionModel(nn.Module):
 
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
-            device=hidden_states.device,
+            device="cpu",
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
-        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
+            hidden_states.device
+        )
 
         # create a cu_seqlens tensor to be used in the attention mask
         cu_seqlens = torch.repeat_interleave(
@@ -928,7 +930,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
                 image_embeds = self.visual(
                     pixel_values, grid_thw=image_grid_thw
                 ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+                mask = torch.where(input_ids == self.image_token_id)
+                inputs_embeds[mask] = image_embeds
 
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 856635fd..3b4965a2 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -503,7 +503,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 image_embeds = self.visual(
                     pixel_values, grid_thw=image_grid_thw
                 ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+                mask = torch.where(input_ids == self.image_token_id)
+                inputs_embeds[mask] = image_embeds
 
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
index 94b8522d..ae704af3 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
@@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )
 
-        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
+        return FlashGemmaForCausalLM(prefix, config, weights)
     elif config.model_type == "gemma2":
         from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
             FlashGemma2ForCausalLM,
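
Reviewer note, not part of the patch: the conv2d workaround in idefics2.py/idefics3.py can be sanity-checked on CPU with stock PyTorch. The sketch below uses made-up shapes and only verifies that a ones-kernel conv2d with stride=patch_size reproduces the per-patch sums of the disabled unfold path; note that the patched code then tests for a fully valid patch (== patch_size * patch_size) where the old code tested for any valid pixel (> 0).

import torch

# Per-patch sums: unfold path (unsupported on HPU) vs. ones-kernel conv2d.
patch_size = 4
pixel_attention_mask = torch.rand(2, 8 * patch_size, 8 * patch_size) > 0.3

patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
sums_unfold = patches_subgrid.sum(dim=(-1, -2))

conv_kernel = torch.ones(1, 1, patch_size, patch_size)
sums_conv = torch.nn.functional.conv2d(
    pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
    conv_kernel,
    stride=patch_size,
).squeeze(1)

# Both compute the number of valid pixels in each patch_size x patch_size patch.
assert torch.equal(sums_unfold.to(sums_conv.dtype), sums_conv)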
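
Likewise, not part of the patch: indexing with torch.where(...) selects the same positions as the boolean `==` mask it replaces in the idefics and qwen2 merge paths, so the in-place copy of image features is unchanged on eager PyTorch. The token id and shapes below are made up for illustration.

import torch

image_token_id = 32000  # made-up id, stands in for config.image_token_id
input_ids = torch.tensor([1, image_token_id, image_token_id, 7, image_token_id])
image_features = torch.randn(3, 8)

embeds_eq = torch.zeros(5, 8)
embeds_where = torch.zeros(5, 8)
embeds_eq[input_ids == image_token_id] = image_features  # old boolean-mask indexing
embeds_where[torch.where(input_ids == image_token_id)] = image_features  # patched indexing

assert torch.equal(embeds_eq, embeds_where)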