feat: prefer gemma specific rms

drbh 2024-03-21 03:28:03 +00:00
parent b307fce653
commit 5b076dfcf2
2 changed files with 27 additions and 21 deletions

File 1 of 2:

@@ -261,6 +261,28 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
 
+    # perform the multiplication in full precision and downcast after
+    def forward_downcast_after(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
+        else:
+            hidden_states = hidden_states * self.weight
+
+        return hidden_states, residual
+
+    def forward(self, hidden_states, residual=None):
+        hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
+        return hidden_states, residual
+
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
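
For context (not part of the commit): the Gemma-specific forward above keeps the weight multiplication in float32 and only downcasts the result, while the generic RMS norm (reverted in the second file below) casts the hidden states to half precision before multiplying. A minimal, self-contained sketch of the difference between the two orderings; the shapes, dtype, and eps value are illustrative only:

    import torch

    x = torch.randn(2, 8, dtype=torch.bfloat16)
    weight = torch.randn(8, dtype=torch.bfloat16)
    eps = 1e-6

    # shared RMS normalization, always done in float32
    hs = x.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)

    # Gemma-specific ordering: multiply by the weight in float32, downcast after
    downcast_after = (hs * weight).to(weight.dtype)

    # generic ordering: downcast to the weight dtype first, then multiply
    downcast_before = weight * hs.to(weight.dtype)

    # the two results can differ by a few bf16 ULPs
    print((downcast_after - downcast_before).abs().max())
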
@@ -473,9 +495,7 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(
-            hidden_states, residual, force_downcast_after=True
-        )
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -492,7 +512,7 @@ class FlashGemmaLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res, force_downcast_after=True
+            attn_output, res
         )
 
         mlp_output = self.mlp(normed_attn_res_output)
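
The two call-site hunks above work because the precision choice now lives in the subclass's forward override instead of a flag threaded through every caller. A toy standalone sketch of that pattern (simplified classes, not the library ones; residual handling omitted):

    import torch
    from torch import nn

    class ToyRMSNorm(nn.Module):
        def __init__(self, weight, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        def forward(self, hidden_states, residual=None):
            # generic ordering: cast to the weight dtype before multiplying
            hs = hidden_states.to(torch.float32)
            hs = hs * torch.rsqrt(hs.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            return self.weight * hs.to(self.weight.dtype), residual

    class ToyGemmaRMSNorm(ToyRMSNorm):
        def forward(self, hidden_states, residual=None):
            # Gemma ordering: multiply in float32, downcast after
            hs = hidden_states.to(torch.float32)
            hs = hs * torch.rsqrt(hs.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            return (hs * self.weight).to(self.weight.dtype), residual

    norm = ToyGemmaRMSNorm(torch.ones(8, dtype=torch.bfloat16))
    out, _ = norm(torch.randn(2, 8, dtype=torch.bfloat16))  # no force_downcast_after flag needed
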

File 2 of 2:

@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)
 
-        def forward(self, hidden_states, residual=None, force_downcast_after=False):
+        def forward(self, hidden_states, residual=None):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,23 +701,9 @@ try:
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    # perform the multiplication in float32 then cast back to half
-                    if force_downcast_after:
-                        hidden_states = (hidden_states * self.weight).to(
-                            self.weight.dtype
-                        )
-                    else:
-                        # cast to half before the multiplication
-                        hidden_states = self.weight * hidden_states.to(
-                            self.weight.dtype
-                        )
-                # avoid converting to half and multiply in float32
-                else:
-                    hidden_states = self.weight * hidden_states
-
-                return hidden_states, residual
+                    hidden_states = hidden_states.to(self.weight.dtype)
 
+                return self.weight * hidden_states, residual
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                 (
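
For reference (also not part of the commit): the large-hidden-size branch of the generic norm reverts to normalize in float32, cast to the weight dtype if it is half precision, then multiply. A standalone stand-in for that reverted branch, with illustrative shapes:

    import torch

    def rms_norm_cast_before(hidden_states, weight, eps=1e-6, residual=None):
        # simplified stand-in for the reverted generic branch (hidden size > 8192)
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)

        # cast to half precision first, then multiply by the weight
        if weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(weight.dtype)
        return weight * hidden_states, residual

    x = torch.randn(1, 9216, dtype=torch.float16)
    w = torch.ones(9216, dtype=torch.float16)
    out, res = rms_norm_cast_before(x, w)
    print(out.dtype, res.dtype)  # torch.float16 torch.float16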