From 704d4ddfaa33bb509eb432a71f29895189e16e5f Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 21 Mar 2024 03:33:56 +0000
Subject: [PATCH] fix: simplify syntax

---
 .../custom_modeling/flash_gemma_modeling.py | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 3eedb766..69c1665d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -262,26 +262,15 @@ class GemmaFastRMSNorm(FastRMSNorm):
         return cls(weight, eps)
 
     # perform the multiplication in full precision and downcast after
-    def forward_downcast_after(self, hidden_states, residual=None):
+    def forward(self, hidden_states, residual=None):
         if residual is not None:
             hidden_states += residual
         residual = hidden_states
-
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
-        else:
-            hidden_states = hidden_states * self.weight
-
-        return hidden_states, residual
-
-    def forward(self, hidden_states, residual=None):
-        hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
-        return hidden_states, residual
+        hidden_states = hidden_states * self.weight
+        return hidden_states.to(self.weight.dtype), residual
 
 
 def load_attention(config, prefix, weights):
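
Reviewer note (not part of the patch): a minimal standalone sketch, assuming
only torch, that checks the consolidated forward matches the removed
forward_downcast_after path for the dtypes handled by the old branch. The
helper names old_forward and new_forward are hypothetical, and the residual
add/return is omitted since the patch does not change it.

    import torch

    def old_forward(hidden_states, weight, eps):
        # removed path: upcast to fp32, normalize, then branch on weight dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # convert into half-precision if necessary
        if weight.dtype in [torch.float16, torch.bfloat16]:
            return (hidden_states * weight).to(weight.dtype)
        return hidden_states * weight

    def new_forward(hidden_states, weight, eps):
        # simplified path: multiply in full precision, downcast unconditionally
        # (.to(weight.dtype) is a no-op when the weight is already fp32,
        # so the dtype branch is unnecessary)
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        hidden_states = hidden_states * weight
        return hidden_states.to(weight.dtype)

    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = torch.randn(2, 8, dtype=dtype)
        w = torch.randn(8, dtype=dtype)
        old, new = old_forward(x, w, 1e-6), new_forward(x, w, 1e-6)
        assert old.dtype == new.dtype == dtype
        assert torch.equal(old, new)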