feat: support force downcast after FastRMSNorm multiply

drbh 2024-03-20 17:47:20 +00:00
parent dfbd9a39a2
commit b307fce653
2 changed files with 22 additions and 6 deletions


@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
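
A note on the hidden_act change above: "gelu_pytorch_tanh" corresponds to the tanh-approximated GELU, whereas plain "gelu" is the exact erf-based form. A minimal sketch of the numerical difference, assuming the standard torch.nn.functional.gelu signature (the sample values are illustrative):

    import torch
    import torch.nn.functional as F

    # exact GELU vs. the tanh approximation that "gelu_pytorch_tanh" refers to
    x = torch.linspace(-4.0, 4.0, steps=9)
    exact = F.gelu(x)                       # erf-based GELU
    approx = F.gelu(x, approximate="tanh")  # tanh-approximated GELU
    print((exact - approx).abs().max())     # small but nonzero difference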
@@ -473,7 +473,9 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        normed_hidden_states, res = self.input_layernorm(
+            hidden_states, residual, force_downcast_after=True
+        )
 
         # Self Attention
         attn_output = self.self_attn(
@@ -490,7 +492,7 @@ class FlashGemmaLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, res, force_downcast_after=True
         )
         mlp_output = self.mlp(normed_attn_res_output)

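The force_downcast_after=True arguments above change where the half-precision rounding happens when the norm output is multiplied by the weight. A minimal sketch of the difference, assuming float32 normalized activations and a bfloat16 weight (shapes and values are illustrative):

    import torch

    # float32 activations, as left by a norm that computes in float32,
    # and a half-precision weight, mirroring the call sites above
    hidden_states = torch.randn(4, 256, dtype=torch.float32)
    weight = torch.randn(256, dtype=torch.bfloat16)

    # default path: round the activations to bfloat16 first, multiply in bfloat16
    cast_before = weight * hidden_states.to(weight.dtype)

    # force_downcast_after path: multiply in float32, round the product once
    cast_after = (hidden_states * weight).to(weight.dtype)

    # the two orderings round at different points, so the results can differ slightly
    print((cast_before.float() - cast_after.float()).abs().max())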

@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)
 
-        def forward(self, hidden_states, residual=None):
+        def forward(self, hidden_states, residual=None, force_downcast_after=False):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,9 +701,23 @@ try:
 
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
-
-                return self.weight * hidden_states, residual
+                    # perform the multiplication in float32 then cast back to half
+                    if force_downcast_after:
+                        hidden_states = (hidden_states * self.weight).to(
+                            self.weight.dtype
+                        )
+                    else:
+                        # cast to half before the multiplication
+                        hidden_states = self.weight * hidden_states.to(
+                            self.weight.dtype
+                        )
+                # avoid converting to half and multiply in float32
+                else:
+                    hidden_states = self.weight * hidden_states
+
+                return hidden_states, residual
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                 (
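
Putting the pieces together, a self-contained sketch of the fallback path this hunk modifies. The float32 normalization step is assumed to follow the usual RMSNorm formula; the class name and the driver code at the bottom are illustrative, not part of the library:

    import torch
    from torch import nn

    class SketchRMSNorm(nn.Module):
        """Illustrative RMSNorm with the optional post-multiply downcast."""

        def __init__(self, weight, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        def forward(self, hidden_states, residual=None, force_downcast_after=False):
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            # normalize in float32 for numerical stability
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                if force_downcast_after:
                    # multiply by the weight in float32, then cast the result back
                    hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
                else:
                    # cast the activations to half precision before the multiply
                    hidden_states = self.weight * hidden_states.to(self.weight.dtype)
            else:
                # float32 weight: multiply in float32, no downcast either way
                hidden_states = self.weight * hidden_states
            return hidden_states, residual

    norm = SketchRMSNorm(torch.ones(256, dtype=torch.bfloat16))
    out, res = norm(torch.randn(2, 256, dtype=torch.bfloat16), force_downcast_after=True)
    print(out.dtype, res.dtype)  # torch.bfloat16 torch.bfloat16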