From 65558b32f4f92d2da80a1becb25ab03d4fe9b0e2 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Mon, 28 Oct 2024 15:14:02 +0000
Subject: [PATCH] fix: add norm after text output

---
 .../models/custom_modeling/qwen2_vl.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index e4fd3325..907fa163 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -31,6 +31,7 @@ import torch.nn.functional as F
 
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
+    FastRMSNorm
 )
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -369,6 +370,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.lm_head = FastLinear.load(
             prefix="lm_head", weights=weights, config=config, bias=False
         )
+        self.norm = FastRMSNorm.load(
+            prefix="model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
         self.device = weights.device
 
     def get_position_ids(
@@ -488,7 +494,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
         position_ids = self.get_position_ids(input_ids, image_grid_thw)
-        outputs = self.text_model(
+        hidden_states = self.text_model(
             inputs_embeds=inputs_embeds.squeeze(0),
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -500,6 +506,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
         )
-
-        logits = self.lm_head(outputs)
+        hidden_states, _ = self.norm(hidden_states)
+        logits = self.lm_head(hidden_states)
         return logits, None