diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index e91927df..69c1665d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -261,6 +261,17 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
 
+    # perform the multiplication in full precision and downcast after
+    def forward(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states * self.weight
+        return hidden_states.to(self.weight.dtype), residual
 
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
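
A minimal standalone sketch (not part of the patch) of the numerics the second hunk changes: RMSNorm is computed in float32 and only downcast to the weight's dtype at the end, rather than normalizing directly in half precision. The RMSNormFullPrecision class and the comparison script below are hypothetical illustrations, not the TGI GemmaFastRMSNorm itself.

import torch


class RMSNormFullPrecision(torch.nn.Module):
    """Illustrative RMSNorm that upcasts to float32, scales, then downcasts."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # upcast, normalize, multiply by the weight, then downcast back
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x * self.weight
        return x.to(self.weight.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.ones(8, dtype=torch.bfloat16)
    x = torch.randn(2, 8, dtype=torch.bfloat16) * 100

    full = RMSNormFullPrecision(w)(x)

    # same math carried out entirely in bfloat16, for comparison
    var = x.pow(2).mean(-1, keepdim=True)
    half = x * torch.rsqrt(var + 1e-6) * w

    # the gap shows the rounding error avoided by normalizing in float32
    print((full.float() - half.float()).abs().max())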