mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

fix: add norm after text output

parent aa2aa9f915
commit 65558b32f4
@@ -31,6 +31,7 @@ import torch.nn.functional as F
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
+    FastRMSNorm
 )
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -369,6 +370,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.lm_head = FastLinear.load(
             prefix="lm_head", weights=weights, config=config, bias=False
         )
+        self.norm = FastRMSNorm.load(
+            prefix="model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
         self.device = weights.device
 
     def get_position_ids(
@@ -488,7 +494,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
         position_ids = self.get_position_ids(input_ids, image_grid_thw)
-        outputs = self.text_model(
+        hidden_states = self.text_model(
             inputs_embeds=inputs_embeds.squeeze(0),
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -500,6 +506,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
         )
-        logits = self.lm_head(outputs)
+        hidden_states, _ = self.norm(hidden_states)
+        logits = self.lm_head(hidden_states)
         return logits, None
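
For context on the fix above: the commit loads the final RMSNorm from the model.norm weights and applies it to the text model's last hidden states before the lm_head projection, where previously the unnormalized states were projected to logits directly. Below is a minimal sketch of the resulting forward path, using torch.nn.RMSNorm and torch.nn.Linear as stand-ins for TGI's FastRMSNorm and FastLinear; the hidden size, vocabulary size, and eps are illustrative assumptions, not values read from the Qwen2-VL config.

import torch
import torch.nn as nn

# Minimal sketch of the fixed forward path (assumed shapes and eps; nn.RMSNorm
# needs a recent PyTorch). Note that TGI's FastRMSNorm returns a (states, residual)
# tuple, which is why the diff unpacks `hidden_states, _ = self.norm(...)`;
# nn.RMSNorm here simply returns a tensor.
hidden_size, vocab_size, eps = 3584, 152064, 1e-6  # illustrative, not from config

norm = nn.RMSNorm(hidden_size, eps=eps)  # stand-in for FastRMSNorm.load(prefix="model.norm", ...)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for FastLinear.load(prefix="lm_head", ...)

hidden_states = torch.randn(10, hidden_size)  # pretend output of self.text_model(...)
hidden_states = norm(hidden_states)           # the newly added normalization step
logits = lm_head(hidden_states)               # same projection as before, now on normalized states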