import torch
from torch import nn
from accelerate import init_empty_weights

from text_generation_server.utils.import_utils import (
    SYSTEM,
)
# Monkey patching: attach checkpoint-loading helpers to torch.nn.LayerNorm so
# that layer norms can be instantiated directly from a weights file.
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = torch.nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
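# A minimal usage sketch (illustrative only): assuming `weights` is the
# server's Weights object exposing `get_tensor`, and the prefix below is a
# hypothetical tensor path whose real value depends on the model being loaded:
#
#     ln = torch.nn.LayerNorm.load(
#         prefix="model.layers.0.input_layernorm", weights=weights, eps=1e-5
#     )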
if SYSTEM == "cuda":
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                # The fused kernel cannot handle hidden sizes above 8192,
                # so fall back to plain PyTorch there.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual
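# For intuition, the fused dropout_add_ln_fwd call above is roughly equivalent
# to this plain-PyTorch sketch (same residual bookkeeping, without the kernel's
# fused single-pass execution; `ln` stands in for the layer norm module):
#
#     def add_ln_reference(ln, hidden_states, residual=None):
#         if residual is not None:
#             hidden_states = hidden_states + residual
#         normed = torch.nn.functional.layer_norm(
#             hidden_states, ln.normalized_shape, ln.weight, ln.bias, ln.eps
#         )
#         return normed, hidden_states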
elif SYSTEM == "rocm":
    import vllm._custom_ops as ops
    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super().forward(hidden_states), residual
elif SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            out = ipex.llm.functional.add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.eps,
                residual is not None,
            )
            return out, residual if residual is not None else hidden_states
class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)
    def forward(self, hidden_states, residual=None):
        if SYSTEM == "ipex":
            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                residual is not None,
            )
            return out, residual if residual is not None else hidden_states
        elif SYSTEM == "rocm":
            # We use the vLLM RMSNorm kernel, which can be compiled for ROCm,
            # instead of the Flash Attention kernels, which cannot.
            if residual is not None:
                ops.fused_add_rms_norm(
                    hidden_states,
                    residual,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return hidden_states, residual

            residual = hidden_states

            out = torch.empty_like(hidden_states)
            ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return out, residual
        elif hidden_states.shape[-1] > 8192:
            # Plain PyTorch fallback for hidden sizes the fused kernel cannot
            # handle: RMSNorm is w * x / sqrt(mean(x^2) + eps), computed in
            # float32 for numerical stability.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
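        # For reference: PyTorch 2.4+ ships an equivalent built-in, so the
        # float32 math above roughly matches (a sketch, modulo the explicit
        # residual handling and the dtype round-trip):
        #
        #     torch.nn.functional.rms_norm(
        #         hidden_states,
        #         (hidden_states.shape[-1],),
        #         weight=self.weight,
        #         eps=self.variance_epsilon,
        #     )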
        elif SYSTEM == "cuda":
            # faster post attention rms norm
            (
                normed_hidden_states,
                res,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
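# A minimal end-to-end sketch (illustrative only, not part of the module):
# building a FastRMSNorm from a raw weight tensor and running the residual
# path. Shapes, dtype, and device below are arbitrary assumptions.
#
#     weight = torch.ones(4096, dtype=torch.float16, device="cuda")
#     norm = FastRMSNorm(weight, eps=1e-6)
#     hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
#     residual = torch.randn_like(hidden)
#     normed, new_residual = norm(hidden, residual)
#     # `new_residual` holds hidden + residual, ready for the next block.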