mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: simplify syntax
This commit is contained in:
parent
5b076dfcf2
commit
704d4ddfaa
@ -262,26 +262,15 @@ class GemmaFastRMSNorm(FastRMSNorm):
|
|||||||
return cls(weight, eps)
|
return cls(weight, eps)
|
||||||
|
|
||||||
# perform the multiplication in full precision and downcast after
|
# perform the multiplication in full precision and downcast after
|
||||||
def forward_downcast_after(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = (hidden_states * self.weight).to(self.weight.dtype)
|
|
||||||
else:
|
|
||||||
hidden_states = hidden_states * self.weight
|
hidden_states = hidden_states * self.weight
|
||||||
|
return hidden_states.to(self.weight.dtype), residual
|
||||||
return hidden_states, residual
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
|
||||||
hidden_states, residual = self.forward_downcast_after(hidden_states, residual)
|
|
||||||
return hidden_states, residual
|
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
|
Loading…
Reference in New Issue
Block a user