feat: support force downcast after FastRMSNorm multiply

drbh 2024-03-20 17:47:20 +00:00
parent dfbd9a39a2
commit b307fce653
2 changed files with 22 additions and 6 deletions


@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
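The new hidden_act default selects the tanh approximation of GELU. As a quick illustration (not part of the diff, and assuming the usual transformers activation mapping), the two strings resolve to the following PyTorch calls:

    # illustration only: "gelu" -> exact (erf-based) GELU, "gelu_pytorch_tanh" -> tanh approximation
    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    exact = F.gelu(x)                       # what hidden_act="gelu" would use
    approx = F.gelu(x, approximate="tanh")  # what hidden_act="gelu_pytorch_tanh" uses
    print((exact - approx).abs().max())     # small but nonzero difference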
@@ -473,7 +473,9 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        normed_hidden_states, res = self.input_layernorm(
+            hidden_states, residual, force_downcast_after=True
+        )

         # Self Attention
         attn_output = self.self_attn(
@@ -490,7 +492,7 @@
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, res, force_downcast_after=True
         )
         mlp_output = self.mlp(normed_attn_res_output)
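Because force_downcast_after defaults to False in FastRMSNorm.forward (see the second file below), call sites outside this file keep the existing cast-then-multiply behavior; only the Gemma layer opts in. A rough sketch of the two call shapes, reusing the identifiers from the hunk above:

    # default path: activations are cast to the weight dtype before the multiply
    normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

    # Gemma path: multiply in float32, downcast the product afterwards
    normed_hidden_states, res = self.input_layernorm(
        hidden_states, residual, force_downcast_after=True
    )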


@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)

-        def forward(self, hidden_states, residual=None):
+        def forward(self, hidden_states, residual=None, force_downcast_after=False):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,9 +701,23 @@
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
+                    # perform the multiplication in float32 then cast back to half
+                    if force_downcast_after:
+                        hidden_states = (hidden_states * self.weight).to(
+                            self.weight.dtype
+                        )
+                    else:
+                        # cast to half before the multiplication
+                        hidden_states = self.weight * hidden_states.to(
+                            self.weight.dtype
+                        )
+                # avoid converting to half and multiply in float32
+                else:
+                    hidden_states = self.weight * hidden_states
+
-                return self.weight * hidden_states, residual
+                return hidden_states, residual
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                 (
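To see why the ordering matters, here is a minimal standalone sketch (not from the commit) of the two paths the new branch switches between; x stands in for the float32-normalized hidden states and w for a half-precision norm weight, both invented for this illustration:

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 256, dtype=torch.float32)  # normalized activations kept in fp32
    w = torch.randn(256, dtype=torch.bfloat16)    # RMSNorm weight stored in bf16

    # force_downcast_after=True: the bf16 weight is promoted, the multiply happens
    # in fp32, and only the product is cast back down to bf16
    after = (x * w).to(w.dtype)

    # force_downcast_after=False: activations are cast to bf16 first, multiply in bf16
    before = w * x.to(w.dtype)

    # the two paths round at different points, so the results can differ slightly
    print((after.float() - before.float()).abs().max())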