mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: add norm after text output
This commit is contained in:
parent
aa2aa9f915
commit
65558b32f4
@ -31,6 +31,7 @@ import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
FastRMSNorm
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
@ -369,6 +370,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
self.lm_head = FastLinear.load(
|
||||
prefix="lm_head", weights=weights, config=config, bias=False
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.device = weights.device
|
||||
|
||||
def get_position_ids(
|
||||
@ -488,7 +494,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
||||
|
||||
position_ids = self.get_position_ids(input_ids, image_grid_thw)
|
||||
outputs = self.text_model(
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds.squeeze(0),
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
@ -500,6 +506,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs)
|
||||
hidden_states, _ = self.norm(hidden_states)
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits, None
|
||||
|
Loading…
Reference in New Issue
Block a user