diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index a23c4a5c..3eedb766 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -261,6 +261,28 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
 
+    # perform the multiplication in full precision and downcast after
+    def forward_downcast_after(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
+        else:
+            hidden_states = hidden_states * self.weight
+
+        return hidden_states, residual
+
+    def forward(self, hidden_states, residual=None):
+        hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
+        return hidden_states, residual
+
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
@@ -473,9 +495,7 @@ class FlashGemmaLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(
-            hidden_states, residual, force_downcast_after=True
-        )
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -492,7 +512,7 @@ class FlashGemmaLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res, force_downcast_after=True
+            attn_output, res
         )
 
         mlp_output = self.mlp(normed_attn_res_output)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 421326cb..209f1c8a 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -687,7 +687,7 @@ try:
             weight = weights.get_tensor(f"{prefix}.weight")
             return cls(weight, eps)
 
-        def forward(self, hidden_states, residual=None, force_downcast_after=False):
+        def forward(self, hidden_states, residual=None):
             if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
@@ -701,23 +701,9 @@ try:
 
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    # perform the multiplication in float32 then cast back to half
-                    if force_downcast_after:
-                        hidden_states = (hidden_states * self.weight).to(
-                            self.weight.dtype
-                        )
-                    else:
-                        # cast to half before the multiplication
-                        hidden_states = self.weight * hidden_states.to(
-                            self.weight.dtype
-                        )
-
-                # avoid converting to half and multiply in float32
-                else:
-                    hidden_states = self.weight * hidden_states
-
-                return hidden_states, residual
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                return self.weight * hidden_states, residual
 
             elif IS_CUDA_SYSTEM:
                 # faster post attention rms norm
                (
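
Reviewer sketch (not part of the patch): the standalone snippet below contrasts the two cast orders this diff separates. GemmaFastRMSNorm keeps the downcast-after multiply (normalize and multiply in float32, then cast the product down), while the generic FastRMSNorm fallback in layers.py casts to half before multiplying. The function names and toy shapes here are illustrative assumptions, not TGI code.

import torch

def rmsnorm_downcast_after(x, weight, eps=1e-6):
    # Normalize in float32, multiply by the weight while still in float32,
    # then cast the product down to the weight's dtype (the behavior
    # GemmaFastRMSNorm.forward_downcast_after implements above).
    h = x.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (h * weight).to(weight.dtype)

def rmsnorm_downcast_before(x, weight, eps=1e-6):
    # Cast the normalized activations to half first, then multiply
    # (the generic fallback path that layers.py keeps after this diff).
    h = x.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return weight * h.to(weight.dtype)

x = torch.randn(4, 2048, dtype=torch.bfloat16)
w = torch.randn(2048, dtype=torch.bfloat16)
after = rmsnorm_downcast_after(x, w)
before = rmsnorm_downcast_before(x, w)
# The two orders differ only in where rounding happens; the diff pins
# Gemma to the downcast-after order via the model-specific subclass
# instead of threading a force_downcast_after flag through layers.py.
print((after.float() - before.float()).abs().max())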