commit 7ccb8eefdc (parent edc9ce9beb)
Nicolas Patry, 2023-05-15 16:43:32 +02:00
7 changed files with 181 additions and 137 deletions

View File

@@ -29,6 +29,7 @@ from typing import Optional

 # Flash attention imports
 import flash_attn_cuda
+import dropout_layer_norm

 from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
@@ -91,12 +92,7 @@ class LlamaRMSNorm(nn.Module):


 class FlashLlamaAttention(torch.nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        hidden_size,
-        process_group=None,
-    ):
+    def __init__(self, num_heads, hidden_size, process_group=None, quantize=None):
         super().__init__()
         self.num_heads = num_heads
         self.hidden_size = hidden_size
@@ -106,8 +102,12 @@ class FlashLlamaAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
-            self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
+            self.query_key_value = FastLinear(
+                hidden_size, 3 * hidden_size, bias=False, quantize=quantize
+            )
+            self.o_proj = FastLinear(
+                hidden_size, hidden_size, bias=False, quantize=quantize
+            )
         else:
             self.num_heads = self.num_heads // process_group.size()
             self.query_key_value = TensorParallelColumnLinear(
@@ -115,12 +115,14 @@ class FlashLlamaAttention(torch.nn.Module):
                 3 * hidden_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.o_proj = TensorParallelRowLinear(
                 hidden_size,
                 hidden_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )

     def forward(
@@ -194,7 +196,9 @@ class FlashLlamaAttention(torch.nn.Module):


 class LlamaMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, quantize=None
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -210,9 +214,11 @@ class LlamaMLP(nn.Module):
         if process_group is None:
             # Fuse gate and up proj
             self.gate_up_proj = FastLinear(
-                hidden_size, 2 * intermediate_size, bias=False
+                hidden_size, 2 * intermediate_size, bias=False, quantize=quantize
             )
-            self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
+            self.down_proj = FastLinear(
+                intermediate_size, hidden_size, bias=False, quantize=quantize
+            )
             self.intermediate_size = intermediate_size
         else:
             # Fuse gate and up proj
@@ -221,6 +227,7 @@ class LlamaMLP(nn.Module):
                 2 * intermediate_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.down_proj = TensorParallelRowLinear(
                 intermediate_size,
@@ -228,6 +235,7 @@ class LlamaMLP(nn.Module):
                 bias=False,
                 process_group=process_group,
                 reduce=True,
+                quantize=quantize,
             )
             self.intermediate_size = self.down_proj.in_features
@@ -248,11 +256,16 @@ class FlashLlamaLayer(nn.Module):
         intermediate_size,
         rms_norm_eps,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()

-        self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
-        self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
+        self.self_attn = FlashLlamaAttention(
+            num_heads, hidden_size, process_group, quantize=quantize
+        )
+        self.mlp = LlamaMLP(
+            act, hidden_size, intermediate_size, process_group, quantize=quantize
+        )

         self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
@@ -309,6 +322,7 @@ class FlashLlamaModel(torch.nn.Module):
             self.embed_tokens = TensorParallelEmbedding(
                 config.vocab_size, config.hidden_size, process_group=process_group
             )
+            self.embed_tokens.add_null_idx()
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -321,6 +335,7 @@ class FlashLlamaModel(torch.nn.Module):
                     config.intermediate_size,
                     config.rms_norm_eps,
                     process_group,
+                    quantize=config.quantize,
                 )
                 for _ in range(config.num_hidden_layers)
             ]
@@ -332,15 +347,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
-        if isinstance(self.embed_tokens, TensorParallelEmbedding):
-            self.embed_tokens.add_null_idx()
-        for layer in self.layers:
-            layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     if isinstance(self.embed_tokens, TensorParallelEmbedding):
+    #         self.embed_tokens.add_null_idx()
+    #     for layer in self.layers:
+    #         layer: FlashLlamaLayer
+    #         layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
+    #         layer.self_attn.o_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.down_proj.prepare_weights(load_in_8bit)

     def forward(
         self,
@@ -429,9 +444,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.model.post_load_weights(load_in_8bit)
-        self.lm_head.prepare_weights()
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     self.model.post_load_weights(load_in_8bit)
+    #     self.lm_head.prepare_weights()

     def forward(
         self,
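The LLaMA-side change is mostly plumbing: every constructor from the model down to the individual projections grows a `quantize` keyword and forwards it, so the quantization choice reaches `FastLinear` at construction time instead of being applied afterwards by `post_load_weights`. A minimal sketch of that pattern, with illustrative `Toy*` names rather than the repository's classes:

```python
from torch import nn


class ToyLinear(nn.Linear):
    # Stand-in for FastLinear: remember the quantization choice at build time.
    def __init__(self, in_features, out_features, bias=False, quantize=None):
        super().__init__(in_features, out_features, bias=bias)
        self.quantize = quantize


class ToyAttention(nn.Module):
    # Stand-in for FlashLlamaAttention: forward the flag to its projections.
    def __init__(self, hidden_size, quantize=None):
        super().__init__()
        self.query_key_value = ToyLinear(hidden_size, 3 * hidden_size, quantize=quantize)
        self.o_proj = ToyLinear(hidden_size, hidden_size, quantize=quantize)


class ToyModel(nn.Module):
    # Stand-in for FlashLlamaModel: read the flag from the config once.
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(
            ToyAttention(config["hidden_size"], quantize=config["quantize"])
            for _ in range(config["num_hidden_layers"])
        )


model = ToyModel({"hidden_size": 16, "num_hidden_layers": 2, "quantize": "bitsandbytes"})
print(model.layers[0].o_proj.quantize)  # bitsandbytes
```

Doing this in the constructors removes the need for a second pass over the module tree once the weights are loaded, which is why the `post_load_weights` methods above are commented out.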

View File

@@ -76,20 +76,20 @@ class FlashNeoxAttention(torch.nn.Module):
                 hidden_size, hidden_size, process_group=process_group, reduce=reduce
             )

-    def shuffle_qkv_dims(self):
-        """Swap dims to avoid an additional permute"""
-        self.query_key_value.weight = torch.nn.Parameter(
-            self.query_key_value.weight.view(
-                self.num_heads, 3, self.head_size, self.hidden_size
-            )
-            .permute(1, 0, 2, 3)
-            .reshape(-1, self.hidden_size)
-        )
-        self.query_key_value.bias = torch.nn.Parameter(
-            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
-            .permute(1, 0, 2)
-            .reshape(-1)
-        )
+    # def shuffle_qkv_dims(self):
+    #     """Swap dims to avoid an additional permute"""
+    #     self.query_key_value.weight = torch.nn.Parameter(
+    #         self.query_key_value.weight.view(
+    #             self.num_heads, 3, self.head_size, self.hidden_size
+    #         )
+    #         .permute(1, 0, 2, 3)
+    #         .reshape(-1, self.hidden_size)
+    #     )
+    #     self.query_key_value.bias = torch.nn.Parameter(
+    #         self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
+    #         .permute(1, 0, 2)
+    #         .reshape(-1)
+    #     )

     def forward(
         self,
@@ -317,6 +317,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             self.embed_in = TensorParallelEmbedding(
                 config.vocab_size, config.hidden_size, process_group=process_group
             )
+            self.embed_in.add_null_idx()
         else:
             self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -345,28 +346,28 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads

-    def post_load_weights(self, load_in_8bit=False):
-        if isinstance(self.embed_in, TensorParallelEmbedding):
-            self.embed_in.add_null_idx()
-        for layer in self.layers:
-            layer: FlashNeoXLayer
-            layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.prepare_weights(load_in_8bit)
-            layer.attention.dense.prepare_weights(load_in_8bit)
-            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
-            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit=False):
+    #     if isinstance(self.embed_in, TensorParallelEmbedding):
+    #         self.embed_in.add_null_idx()
+    #     for layer in self.layers:
+    #         layer: FlashNeoXLayer
+    #         layer.attention.shuffle_qkv_dims()
+    #         layer.attention.query_key_value.prepare_weights(load_in_8bit)
+    #         layer.attention.dense.prepare_weights(load_in_8bit)
+    #         layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
+    #         layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashGPTNeoXModel, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights(load_in_8bit)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+    #     # to do it for us
+    #     load_in_8bit = kwargs.pop("load_in_8bit", False)
+    #     model = super(FlashGPTNeoXModel, cls).from_pretrained(
+    #         pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+    #     )
+    #     model.post_load_weights(load_in_8bit)
+    #     return model

     def forward(
         self,
@@ -451,26 +452,30 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
+                quantize=config.quantize,
             )
         else:
             self.embed_out = FastLinear(
-                config.hidden_size, config.vocab_size, bias=False
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+                quantize=config.quantize,
             )

-    def post_load_weights(self, load_in_8bit=False):
-        self.gpt_neox.post_load_weights(load_in_8bit)
-        self.embed_out.prepare_weights()
+    # def post_load_weights(self, load_in_8bit=False):
+    #     self.gpt_neox.post_load_weights(load_in_8bit)
+    #     self.embed_out.prepare_weights()

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights(load_in_8bit)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+    #     # to do it for us
+    #     load_in_8bit = kwargs.pop("load_in_8bit", False)
+    #     model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+    #         pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+    #     )
+    #     model.post_load_weights(load_in_8bit)
+    #     return model

     def forward(
         self,
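The `add_null_idx()` call that used to live in `post_load_weights` now runs as soon as the tensor-parallel embedding is built. As I understand it, the point of a null index is that each rank only stores a slice of the vocabulary, so out-of-shard token ids are redirected to an extra all-zero row and the partial lookups are then summed across ranks. A self-contained sketch of that idea (the helper below is illustrative, not the repository's `TensorParallelEmbedding`):

```python
import torch
from torch import nn


def shard_lookup(weight_shard, ids, start, end):
    # Append a "null" all-zero row and send out-of-shard ids to it.
    null_idx = weight_shard.shape[0]
    padded = torch.cat([weight_shard, torch.zeros(1, weight_shard.shape[1])])
    in_shard = (ids >= start) & (ids < end)
    local = torch.where(in_shard, ids - start, torch.full_like(ids, null_idx))
    return padded[local]


vocab, dim = 8, 4
full = torch.randn(vocab, dim)
ids = torch.tensor([1, 5, 7])

# Two "ranks", each holding half the vocabulary; the sum plays the role of
# the all-reduce that combines the partial results.
out = shard_lookup(full[:4], ids, 0, 4) + shard_lookup(full[4:], ids, 4, 8)
print(torch.allclose(out, nn.Embedding.from_pretrained(full)(ids)))  # True
```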

View File

@@ -24,6 +24,7 @@ class FlashMQAttention(torch.nn.Module):
         num_heads,
         hidden_size,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -33,15 +34,20 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
-            self.c_proj = FastLinear(hidden_size, hidden_size)
+            self.c_attn = FastLinear(
+                hidden_size, hidden_size + 2 * self.head_size, quantize=quantize
+            )
+            self.c_proj = FastLinear(hidden_size, hidden_size, quantize=quantize)
         else:
             self.num_heads = self.num_heads // process_group.size()
-            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_attn = FastLinear(
+                hidden_size, self.head_size * (self.num_heads + 2), quantize=quantize
+            )
             self.c_proj = TensorParallelRowLinear(
                 hidden_size,
                 hidden_size,
                 process_group=process_group,
+                quantize=quantize,
             )

     def forward(
@@ -123,7 +129,9 @@ class FlashMQAttention(torch.nn.Module):


 class MLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, quantize=None
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -137,18 +145,20 @@ class MLP(nn.Module):
         )

         if process_group is None:
-            self.c_fc = FastLinear(hidden_size, intermediate_size)
-            self.c_proj = FastLinear(intermediate_size, hidden_size)
+            self.c_fc = FastLinear(hidden_size, intermediate_size, quantize=quantize)
+            self.c_proj = FastLinear(intermediate_size, hidden_size, quantize=quantize)
         else:
             self.c_fc = TensorParallelColumnLinear(
                 hidden_size,
                 intermediate_size,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.c_proj = TensorParallelRowLinear(
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
+                quantize=quantize,
             )

     def forward(self, hidden_states):
@@ -167,20 +177,20 @@ class Block(nn.Module):
         intermediate_size,
         layer_norm_eps,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()
         self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attn = FlashMQAttention(
-            num_heads,
-            hidden_size,
-            process_group,
+            num_heads, hidden_size, process_group, quantize=quantize
         )
         self.mlp = MLP(
             act,
             hidden_size,
             intermediate_size,
             process_group,
+            quantize=quantize,
         )

     def forward(
@@ -231,12 +241,14 @@ class FlashSantacoderModel(nn.Module):
                 reduce=False,
                 process_group=process_group,
             )
+            self.wte.add_null_idx()
             self.wpe = TensorParallelEmbedding(
                 config.max_position_embeddings,
                 config.hidden_size,
                 reduce=False,
                 process_group=process_group,
             )
+            self.wpe.add_null_idx()
         else:
             self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
             self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -252,6 +264,7 @@ class FlashSantacoderModel(nn.Module):
                     else 4 * config.hidden_size,
                     config.layer_norm_epsilon,
                     process_group,
+                    quantize=config.quantize,
                 )
                 for _ in range(config.num_hidden_layers)
             ]
@@ -261,16 +274,13 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
-        if self.tp_embeddings:
-            self.wte.add_null_idx()
-            self.wpe.add_null_idx()
-        for layer in self.h:
-            layer: Block
-            layer.attn.c_attn.prepare_weights(load_in_8bit)
-            layer.attn.c_proj.prepare_weights(load_in_8bit)
-            layer.mlp.c_fc.prepare_weights(load_in_8bit)
-            layer.mlp.c_proj.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     for layer in self.h:
+    #         layer: Block
+    #         layer.attn.c_attn.prepare_weights(load_in_8bit)
+    #         layer.attn.c_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.c_fc.prepare_weights(load_in_8bit)
+    #         layer.mlp.c_proj.prepare_weights(load_in_8bit)

     def forward(
         self,
@@ -343,13 +353,16 @@ class FlashSantacoderForCausalLM(nn.Module):
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
+                quantize=None,
             )
         else:
-            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+            self.lm_head = FastLinear(
+                config.hidden_size, config.vocab_size, bias=False, quantize=None
+            )

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.transformer.post_load_weights(load_in_8bit)
-        self.lm_head.prepare_weights()
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     self.transformer.post_load_weights(load_in_8bit)
+    #     self.lm_head.prepare_weights()

     def forward(
         self,
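A note on the Santacoder sizes above: the fused `c_attn` projection implements multi-query attention, so it produces all the query heads plus a single shared key and value head, i.e. `hidden_size + 2 * head_size` outputs (or `head_size * (num_heads + 2)` once the query heads are split across ranks). A small illustrative split of such a fused output, not the file's actual forward code:

```python
import torch
from torch import nn

hidden_size, num_heads = 64, 8
head_size = hidden_size // num_heads

# Multi-query fused projection: every query head, then one K and one V head.
c_attn = nn.Linear(hidden_size, hidden_size + 2 * head_size)

x = torch.randn(3, hidden_size)                    # 3 tokens
q, kv = c_attn(x).split([hidden_size, 2 * head_size], dim=-1)
q = q.view(-1, num_heads, head_size)               # per-head queries
k, v = kv.split(head_size, dim=-1)                 # shared by all heads
print(q.shape, k.shape, v.shape)                   # [3, 8, 8] [3, 8] [3, 8]
```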

View File

@@ -140,7 +140,7 @@ class FlashLlama(FlashCausalLM):
                     del value

         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)


 class FlashLlamaSharded(FlashLlama):
@@ -307,4 +307,4 @@ class FlashLlamaSharded(FlashLlama):
                     module._buffers[param_name] = tensor

         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)

View File

@@ -152,4 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
             else:
                 module._buffers[param_name] = tensor

-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)

View File

@@ -160,7 +160,7 @@ class FlashSantacoder(FlashCausalLM):
                     del value

         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)

     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
@@ -378,4 +378,4 @@ class FlashSantacoderSharded(FlashSantacoder):
             model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)

         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
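In the three loader files above, the trailing `model.post_load_weights(quantize)` call is simply commented out; since the modeling code now reads `config.quantize` while building its layers, the flag has to be attached to the config before the model is constructed. A hypothetical sketch of that hand-off (the real loaders presumably set the attribute on the Hugging Face config object; `SimpleNamespace` just stands in for it here):

```python
from types import SimpleNamespace


def attach_quantize(config, quantize=None):
    # Record the serve-time quantization choice where the modeling code
    # (which reads `config.quantize` in its constructors) can see it.
    config.quantize = quantize
    return config


config = attach_quantize(SimpleNamespace(hidden_size=4096), quantize="bitsandbytes")
print(config.quantize)  # bitsandbytes
```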

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import nn
-import dropout_layer_norm
+import torch.nn.functional as F

 HAS_BITS_AND_BYTES = True
 try:
@@ -18,12 +18,11 @@ class FastLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
+        quantize=None,
     ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
+        self.quantize = quantize
         self.bnb_linear = None

-    def prepare_weights(self, quantize: bool = False):
         if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
@@ -33,6 +32,7 @@ class FastLinear(nn.Linear):
                 )

             self.quantized = True
+            super().__init__(in_features, out_features, bias, device, dtype)
             self.bnb_linear = Linear8bitLt(
                 self.in_features,
                 self.out_features,
@@ -51,12 +51,13 @@ class FastLinear(nn.Linear):
         elif quantize == "gptq":
             raise NotImplementedError("`gptq` is not implemented for now")
         elif quantize is None:
+            super().__init__(in_features, out_features, bias, device, dtype)
             self.weight = nn.Parameter(self.weight.T)
         else:
             raise ValueError(f"Unexpected quantize `{quantize}`")

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
+        if self.quantize:
             return self.bnb_linear(input)
         else:
             if self.bias is not None:
@@ -73,6 +74,7 @@ class TensorParallelColumnLinear(FastLinear):
         bias=True,
         device=None,
         dtype=None,
+        quantize=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -85,6 +87,7 @@ class TensorParallelColumnLinear(FastLinear):
             bias=bias,
             device=device,
             dtype=dtype,
+            quantize=quantize,
         )

@@ -98,6 +101,7 @@ class TensorParallelRowLinear(FastLinear):
         bias=True,
         device=None,
         dtype=None,
+        quantize=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -111,6 +115,7 @@ class TensorParallelRowLinear(FastLinear):
             bias=bias,
             device=device,
             dtype=dtype,
+            quantize=quantize,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -182,40 +187,46 @@ class TensorParallelEmbedding(nn.Embedding):
         return out


-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
+try:
+    import dropout_layer_norm
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
+                residual = hidden_states
+
+                return super().forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+except ImportError:
+    pass


 try:
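Net effect in the layers module: a `FastLinear` now decides its storage format in `__init__` (building a bitsandbytes `Linear8bitLt` when `quantize == "bitsandbytes"`, otherwise storing a transposed floating-point weight) and dispatches on that flag in `forward`, which is what makes the separate `prepare_weights` pass unnecessary. A reduced sketch of that construct-then-dispatch pattern (the bitsandbytes wiring is omitted; this is not the repository's implementation):

```python
import torch
from torch import nn


class SketchLinear(nn.Linear):
    # Reduced stand-in for FastLinear: choose the kernel when constructed.
    def __init__(self, in_features, out_features, bias=False, quantize=None):
        super().__init__(in_features, out_features, bias=bias)
        self.quantize = quantize
        self.bnb_linear = None
        if quantize == "bitsandbytes":
            # The real code builds a bitsandbytes Linear8bitLt here and moves
            # the weights into it; left out of this sketch.
            raise NotImplementedError("sketch only covers the fp path")
        elif quantize is None:
            # Keep the transposed weight so forward is a plain matmul/addmm.
            self.weight = nn.Parameter(self.weight.T)
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")

    def forward(self, x):
        if self.quantize:                       # quantized path
            return self.bnb_linear(x)
        if self.bias is not None:
            return torch.addmm(self.bias, x, self.weight)
        return torch.matmul(x, self.weight)


layer = SketchLinear(8, 16)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])
```

The same file also makes `dropout_layer_norm` an optional dependency: `FastLayerNorm` is now only defined inside a `try: import dropout_layer_norm ... except ImportError: pass` block, mirroring the other `try:` guards already present in the module.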