diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index e91927df..a23c4a5c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -473,7 +473,9 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        normed_hidden_states, res = self.input_layernorm(
+            hidden_states, residual, force_downcast_after=True
+        )
 
         # Self Attention
         attn_output = self.self_attn(
@@ -490,7 +492,7 @@ class FlashGemmaLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, res, force_downcast_after=True
         )
 
         mlp_output = self.mlp(normed_attn_res_output)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 209f1c8a..421326cb 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)
 
-        def forward(self, hidden_states, residual=None):
+        def forward(self, hidden_states, residual=None, force_downcast_after=False):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,9 +701,23 @@ try:
 
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
+                    # perform the multiplication in float32 then cast back to half
+                    if force_downcast_after:
+                        hidden_states = (hidden_states * self.weight).to(
+                            self.weight.dtype
+                        )
+                    else:
+                        # cast to half before the multiplication
+                        hidden_states = self.weight * hidden_states.to(
+                            self.weight.dtype
+                        )
+
+                # avoid converting to half and multiply in float32
+                else:
+                    hidden_states = self.weight * hidden_states
+
+                return hidden_states, residual
 
-                return self.weight * hidden_states, residual
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                 (
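
The point of the new `force_downcast_after` flag is the order of the downcast relative to the elementwise scale: with the flag set (as the Gemma layers request above), the multiplication by the RMSNorm weight runs in float32 and only the product is cast to the weight dtype; without it, the normalized activations are cast to half precision first and the multiplication runs in reduced precision. A minimal sketch of the numerical difference, assuming only plain PyTorch; the shapes and seed here are illustrative and not taken from the patch:

```python
import torch

torch.manual_seed(0)

# Illustrative stand-ins for the RMSNorm-normalized activations (kept in
# float32 by the forward pass) and the learned scale stored in half precision.
hidden_states = torch.randn(4, 256, dtype=torch.float32)
weight = torch.randn(256, dtype=torch.bfloat16)

# Default path: cast the activations to the weight dtype, multiply in bfloat16.
cast_before = weight * hidden_states.to(weight.dtype)

# force_downcast_after=True path: multiply in float32 (PyTorch promotes the
# bfloat16 weight), then cast the product down to bfloat16.
cast_after = (hidden_states * weight).to(weight.dtype)

# The two orderings can disagree by a few bfloat16 ULPs per element; the Gemma
# layers in the patch opt into the float32 multiplication.
print((cast_before.float() - cast_after.float()).abs().max())
```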