fix: simplify syntax

This commit is contained in:
drbh 2024-03-21 03:33:56 +00:00
parent 5b076dfcf2
commit 704d4ddfaa

View File

@@ -262,26 +262,15 @@ class GemmaFastRMSNorm(FastRMSNorm):
return cls(weight, eps)
# perform the multiplication in full precision and downcast after
def forward_downcast_after(self, hidden_states, residual=None):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
else:
hidden_states = hidden_states * self.weight
return hidden_states, residual
def forward(self, hidden_states, residual=None):
hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
return hidden_states, residual
return hidden_states.to(self.weight.dtype), residual
def load_attention(config, prefix, weights):