mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: support force downcast after FastRMSNorm multiply
This commit is contained in:
parent
dfbd9a39a2
commit
b307fce653
@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu",
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
@ -473,7 +473,9 @@ class FlashGemmaLayer(nn.Module):
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
normed_hidden_states, res = self.input_layernorm(
|
||||
hidden_states, residual, force_downcast_after=True
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
@ -490,7 +492,7 @@ class FlashGemmaLayer(nn.Module):
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
attn_output, res, force_downcast_after=True
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
@ -687,7 +687,7 @@ try:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
return cls(weight, eps)
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
def forward(self, hidden_states, residual=None, force_downcast_after=False):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
@ -701,9 +701,23 @@ try:
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
# perform the multiplication in float32 then cast back to half
|
||||
if force_downcast_after:
|
||||
hidden_states = (hidden_states * self.weight).to(
|
||||
self.weight.dtype
|
||||
)
|
||||
else:
|
||||
# cast to half before the multiplication
|
||||
hidden_states = self.weight * hidden_states.to(
|
||||
self.weight.dtype
|
||||
)
|
||||
|
||||
# avoid converting to half and multiply in float32
|
||||
else:
|
||||
hidden_states = self.weight * hidden_states
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
elif IS_CUDA_SYSTEM:
|
||||
# faster post attention rms norm
|
||||
(
|
||||
|
Loading…
Reference in New Issue
Block a user