diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index e4fd3325..907fa163 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -31,6 +31,7 @@ import torch.nn.functional as F from text_generation_server.layers.layernorm import ( FastLayerNorm, + FastRMSNorm ) from text_generation_server.layers import ( TensorParallelColumnLinear, @@ -369,6 +370,11 @@ class Qwen2VLForConditionalGeneration(nn.Module): self.lm_head = FastLinear.load( prefix="lm_head", weights=weights, config=config, bias=False ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) self.device = weights.device def get_position_ids( @@ -488,7 +494,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): inputs_embeds[input_ids == self.image_token_id] = image_embeds position_ids = self.get_position_ids(input_ids, image_grid_thw) - outputs = self.text_model( + hidden_states = self.text_model( inputs_embeds=inputs_embeds.squeeze(0), position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -500,6 +506,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, ) - - logits = self.lm_head(outputs) + hidden_states, _ = self.norm(hidden_states) + logits = self.lm_head(hidden_states) return logits, None