diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 3a91e94c..76e64f3a 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -67,6 +67,10 @@ try: from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( FlashGemma2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_gemma3_modeling import ( + Gemma3ForConditionalGeneration, + FlashGemma3ForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( FlashDbrxForCausalLM, DbrxConfig, @@ -220,6 +224,16 @@ class ModelType(enum.Enum): "name": "Gemma2", "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", } + GEMMA3 = { + "type": "gemma3", + "name": "Gemma3", + "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d", + } + GEMMA3_TEXT = { + "type": "gemma3_text", + "name": "Gemma3 Text", + "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d", + } COHERE = { "type": "cohere", "name": "Cohere", @@ -630,6 +644,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, @@ -675,6 +690,34 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) + elif model_type == GEMMA3: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Gemma3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + support_chunking=False, + ) + elif model_type == GEMMA3_TEXT: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) elif model_type == COHERE: return FlashCausalLM( model_id=model_id, @@ -864,6 +907,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py new file mode 100644 index 00000000..92f059bc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -0,0 +1,733 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +from torch import nn +from typing import Optional, List, Tuple +import copy + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, + # + SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) + +import torch + + +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) + + +from text_generation_server.layers.attention.kv_cache import get_kv_scales +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from text_generation_server.utils.weights import UnquantizedWeight +from transformers.activations import ACT2FN +from text_generation_server.layers.attention import ( + paged_attention, + attention, + Seqlen, + set_block_mapping, + HPUPagedAttentionMetadata, +) +import habana_frameworks.torch as htorch + +ATTENTION_TYPE_GLOBAL = "global" +ATTENTION_TYPE_LOCAL = "local_sliding" + + +class Gemma3FastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix: str, weights, eps=1e-6): + dtype = weights.dtype + weights.dtype = torch.float32 + weight = weights.get_tensor(f"{prefix}.weight") + 1 + weights.dtype = dtype + new = cls(weight, eps) + new.dtype = dtype + return new + + # perform the multiplication in full precision and downcast after + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * self.weight + return hidden_states.to(self.dtype), residual + + +def load_attention(config, prefix: str, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias=None)) + + +class FlashGemma3Attention(torch.nn.Module): + def __init__( + self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + self.causal = causal + if is_sliding: + self.window_size = config.sliding_window + # TODO: remove 
this hack to support local sliding window + config = copy.deepcopy(config) + config.rope_scaling = dict(rope_type="default") + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_local_base_freq, + device=weights.device, + ) + else: + self.window_size = -1 + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = ( + config.query_pre_attn_scalar**-0.5 + if config.query_pre_attn_scalar is not None + else None + ) + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.softcap = None # config.attn_logit_softcapping + + query_key_value = load_attention(config, prefix, weights) + self.query_key_value = TensorParallelMultiAdapterLinear.load( + query_key_value, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + self.head_size * config.num_attention_heads, + self.head_size * config.num_key_value_heads, + self.head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + layer_id, + "o_proj", + process_group=weights.process_group, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + self.q_norm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_norm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps + ) + self.enable_gqa = self.num_heads != self.num_key_value_heads + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + hpu_attention_meta, + ): + + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + + kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size) + key = kv[:, 0] + value = kv[:, 1] + + query = query.reshape(-1, self.head_size) + key = key.reshape(-1, self.head_size) + + query, _ = self.q_norm(query.contiguous()) + key, _ = self.k_norm(key.contiguous()) + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_key_value_heads, self.head_size) + value = value.view(-1, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, key, cos, sin) + + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + attn_output = attention( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + window_size_left=self.window_size, + softcap=self.softcap, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache, + 
self.kv_head_mapping, + self.softmax_scale, + seqlen, + softcap=self.softcap, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) + + +class Gemma3MLP(nn.Module): + def __init__(self, prefix, config, weights, layer_id): + super().__init__() + act = config.hidden_activation + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + layer_id, + "down_proj", + process_group=weights.process_group, + ) + + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) + + +class FlashGemma3Layer(nn.Module): + def __init__( + self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + ): + super().__init__() + self.self_attn = FlashGemma3Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + causal=causal, + is_sliding=is_sliding, + ) + self.mlp = Gemma3MLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id + ) + + self.input_layernorm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.pre_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.post_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + hpu_attention_meta, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + hpu_attention_meta, + ) + + # faster post attention rms norm + normed_attn_res_output, _ = self.post_attention_layernorm(attn_output) + normed_attn_res_output = normed_attn_res_output + res + res = normed_attn_res_output + + pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output) + mlp_output = self.mlp(pre_normed, adapter_data) + post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output) + + return post_hidden_states, 
normed_attn_res_output + + +class FlashGemma3Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.layers = nn.ModuleList( + [ + FlashGemma3Layer( + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + layer_id=layer_id, + causal=causal, + is_sliding=bool((layer_id + 1) % config.sliding_window_pattern), + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + adapter_data: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + hidden_states = inputs_embeds + + residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + + residual = None + for i, layer in enumerate(self.layers): + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = layer.self_attn.rotary_emb.get_cos_sin(position_ids) + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + adapter_data, + hpu_attention_meta, + ) + if lazy_mode: + htorch.core.mark_step() + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemma3ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): + super().__init__() + + embed_norm = config.hidden_size**0.5 + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemma3Model( + prefix=prefix, config=config, weights=weights, causal=causal + ) + self.lm_head = SpeculativeHead.load( + prefix=( + f"{prefix}.embed_tokens" + if config.tie_word_embeddings + else f"{prefix}.lm_head" + ), + config=config, + weights=weights, + ) + # self.softcap = config.attn_logit_softcapping + # assert isinstance(self.softcap, float) + self.softcap = None + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_embeds = self.embed_tokens(input_ids) + + hidden_states = self.model( + input_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + hpu_attention_meta, + ) + if lm_head_indices is not None: + 
hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + + return logits, speculative_logits + + +class Gemma3MultimodalInputProjection(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.mm_input_projection_weight = weights.get_tensor( + "multi_modal_projector.mm_input_projection_weight" + ) + + self.mm_soft_emb_norm = Gemma3FastRMSNorm.load( + prefix=f"{prefix}.mm_soft_emb_norm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) + + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight + ) + return projected_vision_outputs.type_as(vision_outputs) + + +class Gemma3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.config = config + + if config.vision_config is not None: + + config.vision_config.quantize = config.quantize + + self.post_vision_model_layernorm = nn.LayerNorm.load( + prefix="vision_tower.vision_model.post_layernorm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) + + self.multimodal_projector = Gemma3MultimodalInputProjection( + prefix="multi_modal_projector", + config=config, + weights=weights, + ) + + text_config = config.text_config + text_config.speculator = config.speculator + text_config.quantize = config.quantize + + self.vision_model = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + else: + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.text_model = load_text_model( + prefix=prefix, + config=config.text_config, + weights=weights, + ) + + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + self.dtype = weights.dtype + + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_model(pixel_values) + vision_outputs = self.post_vision_model_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multimodal_projector(vision_outputs) + image_features = image_features.view(-1, 
image_features.shape[-1]) + return image_features + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # Replace the image token embeddings with the vision features + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + inputs_embeds[image_token_mask] = vision_embeds.view( + -1, vision_embeds.shape[-1] + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if cu_seqlen_prefill is not None: + position_ids += 1 + + if attention_mask is not None: + min_dtype = torch.finfo(inputs_embeds.dtype).min + # prefill may be larger than sliding window + effective_seq_len = max( + position_ids.shape[0], self.config.text_config.sliding_window + ) + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), + diagonal=-self.config.text_config.sliding_window, + ) + attention_mask_local = torch.where( + sliding_window_mask, min_dtype, attention_mask + ) + offset = max(0, position_ids.shape[0] - effective_seq_len) + attention_mask_local = attention_mask_local[ + :, :, :, offset : offset + effective_seq_len + ] + else: + attention_mask_local = None + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, + ) + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py index ae704af3..23cba7b0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -23,6 +23,12 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashGemma2ForCausalLM(prefix, config, weights) + elif config.model_type == "gemma3" or config.model_type == "gemma3_text": + from text_generation_server.models.custom_modeling.flash_gemma3_modeling import ( + FlashGemma3ForCausalLM, + ) + + return FlashGemma3ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, @@ -42,13 +48,20 @@ def load_vision_model(prefix, config, weights): return CLIPVisionTransformer( prefix=f"{prefix}.vision_model", config=config, weights=weights ) - if config.model_type == "siglip_vision_model": + if ( + config.model_type == "siglip_vision_model" + or config.model_type == "gemma3_vision" + ): from text_generation_server.models.custom_modeling.siglip import ( SiglipVisionTransformer, ) + # TODO: ensure that using the prefix doesn't break any existing models + # that 
rely on the old prefix (update the old models if necessary) return SiglipVisionTransformer( - prefix="vision_tower.vision_model", config=config, weights=weights + prefix=f"{prefix}.vision_model", + config=config, + weights=weights, ) else: raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 51dca2e9..ca62560e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -80,19 +80,6 @@ from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) -# Will be set in init -SLIDING_WINDOW: Optional[int] = None - - -def set_sliding_window(sliding_window: int): - global SLIDING_WINDOW - SLIDING_WINDOW = sliding_window - - -def get_sliding_windows() -> int: - global SLIDING_WINDOW - return SLIDING_WINDOW - def prepare_for_decode( dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx @@ -1112,7 +1099,6 @@ class FlashCausalLMBatch(Batch): self.cache_lengths_tensor, (0, extra_pad_bs), value=0 ) - sliding_window = get_sliding_windows() position_ids = [] slot_indices = [] prefill_cache_indices = [] @@ -1178,9 +1164,7 @@ class FlashCausalLMBatch(Batch): # Create tensor to slice into the kv tensor in prefill # hpu need request_prefill_cache_indices to skip padding in kv cache - sliding_window = get_sliding_windows() - if sliding_window is None: - sliding_window = input_length + sliding_window = input_length cumulative_length += input_ids_padded_length[i] if sliding_window is not None: request_prefill_cache_indices = torch.arange( @@ -1457,9 +1441,7 @@ class FlashCausalLM(Model): if text_config is not None: config = text_config - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: + if getattr(config, "sliding_window", None) is None: config.sliding_window = None self.num_layers = config.num_hidden_layers diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 4220f40e..54c35c58 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -1001,17 +1001,8 @@ class FlashVlmCausalLM(FlashCausalLM): attention_mask = None attention_mask_forward = None - if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: - attention_mask = self.model.get_attention_mask( - input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True - ) - min_dtype = torch.finfo(self.dtype).min - attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to( - input_ids.device - ) - attention_mask = attention_mask.reshape(-1) if self.model.config.model_type == "llama4": - attention_mask = (input_ids != 0).long() + attention_mask = (input_ids != self.tokenizer.pad_token_id).long() attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1) if cu_seqlen_prefill is None and self.max_past() is not None: diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index 66c69bc1..b936bb69 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ 
b/backends/gaudi/server/text_generation_server/models/model.py
@@ -79,7 +79,7 @@ class Model(ABC):
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
-            window_size=self.sliding_window,
+            window_size=None,
             speculate=self.speculate,
             block_size=BLOCK_SIZE,
         )
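
Note: a minimal sketch (not taken from this patch) of how the `is_sliding` flag computed in FlashGemma3Model maps onto Gemma 3's attention layout. The value `sliding_window_pattern=6` is an assumption based on the public Gemma 3 text configs (five local-sliding layers followed by one global layer); the real code reads it from `config.sliding_window_pattern`.

    # Sketch only: mirrors `is_sliding=bool((layer_id + 1) % config.sliding_window_pattern)`
    # from FlashGemma3Model above. The attention-type names match the constants defined
    # in flash_gemma3_modeling.py; sliding_window_pattern=6 is an assumed default.
    ATTENTION_TYPE_GLOBAL = "global"
    ATTENTION_TYPE_LOCAL = "local_sliding"

    def layer_attention_types(num_hidden_layers: int, sliding_window_pattern: int = 6) -> list:
        types = []
        for layer_id in range(num_hidden_layers):
            is_sliding = bool((layer_id + 1) % sliding_window_pattern)
            types.append(ATTENTION_TYPE_LOCAL if is_sliding else ATTENTION_TYPE_GLOBAL)
        return types

    # For a 12-layer model this yields global attention on layers 5 and 11 (0-indexed)
    # and the local sliding window everywhere else.
    assert layer_attention_types(12)[5] == ATTENTION_TYPE_GLOBAL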