fix: add norm after text output

David Holtz 2024-10-28 15:14:02 +00:00 committed by drbh
parent aa2aa9f915
commit 65558b32f4


@@ -31,6 +31,7 @@ import torch.nn.functional as F
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
+    FastRMSNorm,
 )
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -369,6 +370,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.lm_head = FastLinear.load(
             prefix="lm_head", weights=weights, config=config, bias=False
         )
+        self.norm = FastRMSNorm.load(
+            prefix="model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
         self.device = weights.device
 
     def get_position_ids(
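For context, the newly loaded model.norm is the model's final RMSNorm, which the previous code skipped entirely: the text model's last hidden states went straight to the LM head. Below is a minimal reference sketch of RMS normalization in plain PyTorch; it is not TGI's FastRMSNorm (whose forward also returns a residual, which is why the forward pass below unpacks a tuple), and the class name and defaults are illustrative only.

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    # Reference RMSNorm: scale by the reciprocal root-mean-square over the
    # hidden dimension, then by a learned per-channel weight (no mean subtraction).
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)  # compute statistics in fp32 for stability
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(dtype)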
@@ -488,7 +494,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         inputs_embeds[input_ids == self.image_token_id] = image_embeds
         position_ids = self.get_position_ids(input_ids, image_grid_thw)
-        outputs = self.text_model(
+        hidden_states = self.text_model(
             inputs_embeds=inputs_embeds.squeeze(0),
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -500,6 +506,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
         )
+        hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(outputs)
+        logits = self.lm_head(hidden_states)
         return logits, None
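The functional change is an ordering fix in the forward pass: the last hidden states now pass through model.norm before the LM head instead of reaching it un-normalized. A self-contained sketch of the corrected tail follows; the module stand-ins, shapes, and eps value are hypothetical and only illustrate the ordering, not the actual TGI classes.

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same math as the final model.norm: scale by the reciprocal RMS, then by the learned weight.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


hidden_size, vocab_size, seq_len = 16, 32, 4
norm_weight = torch.ones(hidden_size)                            # stands in for the model.norm weight
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)   # stands in for self.lm_head
hidden_states = torch.randn(seq_len, hidden_size)                # pretend output of self.text_model(...)

# Before the fix: logits = lm_head(hidden_states)   -- the head saw raw, un-normalized activations.
# After the fix: normalize first, then project to vocabulary logits.
logits = lm_head(rms_norm(hidden_states, norm_weight, eps=1e-6))
print(logits.shape)  # torch.Size([4, 32])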