mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: prefer gemma specific rms
commit 5b076dfcf2 (parent b307fce653)
@@ -261,6 +261,28 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
 
+    # perform the multiplication in full precision and downcast after
+    def forward_downcast_after(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
+        else:
+            hidden_states = hidden_states * self.weight
+
+        return hidden_states, residual
+
+    def forward(self, hidden_states, residual=None):
+        hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
+        return hidden_states, residual
+
+
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
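
For reference, here is a minimal standalone sketch of the arithmetic the new forward_downcast_after method performs: the residual add and normalization run in float32, the weight multiplication also stays in float32, and the result is only downcast to the weight dtype at the very end. The helper name, tensor shapes, and epsilon below are illustrative assumptions, not TGI code.

import torch

def gemma_rms_norm_downcast_after(hidden_states, weight, eps=1e-6, residual=None):
    # Optional fused residual add, mirroring the method above.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states

    # Normalize in float32 for numerical stability.
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)

    # Multiply by the weight while still in float32; downcast only at the end.
    if weight.dtype in (torch.float16, torch.bfloat16):
        return (x * weight).to(weight.dtype), residual
    return x * weight, residual

h = torch.randn(2, 5, 256, dtype=torch.bfloat16)
w = torch.randn(256, dtype=torch.bfloat16) * 0.1 + 1  # GemmaFastRMSNorm.load adds 1 to the stored weight
out, res = gemma_rms_norm_downcast_after(h, w)
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([2, 5, 256])

Because the Gemma checkpoint stores weight - 1 (hence the + 1 in the load classmethod above), the effective scale sits close to 1, which presumably motivates doing the multiply in full precision and rounding only once.
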
@@ -473,9 +495,7 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(
-            hidden_states, residual, force_downcast_after=True
-        )
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -492,7 +512,7 @@ class FlashGemmaLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res, force_downcast_after=True
+            attn_output, res
        )
 
         mlp_output = self.mlp(normed_attn_res_output)
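
The two hunks above drop the force_downcast_after keyword from the call sites in FlashGemmaLayer; the precision policy now lives entirely in the norm class. A toy sketch of the resulting contract follows (ToyRMSNorm, ToyLayer, and the stand-in Linear modules are assumptions, not the real FlashGemmaLayer): each norm fuses the residual add and returns both the normalized activations and the new residual, which the layer threads into the next norm.

import torch
import torch.nn as nn

class ToyRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        # Fuse the residual add and hand the new residual back to the caller.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        x = hidden_states.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return (x * self.weight).to(self.weight.dtype), residual

class ToyLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.input_layernorm = ToyRMSNorm(hidden_size)
        self.post_attention_layernorm = ToyRMSNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # stand-in for the MLP

    def forward(self, hidden_states, residual=None):
        # No force_downcast_after flag any more: the norm class decides the precision policy.
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
        attn_output = self.self_attn(normed_hidden_states)
        normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res)
        mlp_output = self.mlp(normed_attn_res_output)
        return mlp_output, attn_res

layer = ToyLayer(64)
out, res = layer(torch.randn(2, 3, 64))
print(out.shape, res.shape)  # torch.Size([2, 3, 64]) torch.Size([2, 3, 64])
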
@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)
 
-        def forward(self, hidden_states, residual=None, force_downcast_after=False):
+        def forward(self, hidden_states, residual=None):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,23 +701,9 @@ try:
 
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    # perform the multiplication in float32 then cast back to half
-                    if force_downcast_after:
-                        hidden_states = (hidden_states * self.weight).to(
-                            self.weight.dtype
-                        )
-                    else:
-                        # cast to half before the multiplication
-                        hidden_states = self.weight * hidden_states.to(
-                            self.weight.dtype
-                        )
-
-                # avoid converting to half and multiply in float32
-                else:
-                    hidden_states = self.weight * hidden_states
-
-                return hidden_states, residual
-
+                    hidden_states = hidden_states.to(self.weight.dtype)
+
+                return self.weight * hidden_states, residual
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                 (
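
The last two hunks revert the shared FastRMSNorm to its earlier behaviour on the large-hidden-size fallback path: cast the normalized activations to the weight dtype first, then multiply, while the Gemma-specific norm above multiplies in float32 and downcasts afterwards. Below is a rough numerical sketch of the two orderings against a float64 reference (random data and bfloat16 weights are assumptions, not a measurement from this PR).

import torch

torch.manual_seed(0)
x = torch.randn(4, 1024, dtype=torch.float32)
w = (torch.randn(1024) * 0.1 + 1).to(torch.bfloat16)  # weights near 1, as in Gemma
eps = 1e-6

normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Generic FastRMSNorm ordering: cast to the weight dtype, then multiply.
downcast_before = w * normed.to(w.dtype)

# Gemma-specific ordering: multiply in float32, then downcast.
downcast_after = (normed * w.to(torch.float32)).to(w.dtype)

# float64 reference for comparison
reference = normed.double() * w.double()

print("max abs err, downcast before:", (downcast_before.double() - reference).abs().max().item())
print("max abs err, downcast after: ", (downcast_after.double() - reference).abs().max().item())
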