OlivierDehaene 2023-04-05 16:29:34 +02:00
parent f26dfd0dc1
commit e8a3ec36c3
4 changed files with 382 additions and 26 deletions

server/text_generation_server/models/__init__.py

@@ -20,6 +20,8 @@ try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
+    from text_generation_server.models.flash_santacoder import FlashSantacoder, FlashSantacoderSharded

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -49,6 +51,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(FlashLlamaSharded)
@@ -78,9 +81,11 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)

-    if "santacoder" in model_id:
+    if "bigcode" in model_id:
         if sharded:
-            raise NotImplementedError("sharded is not supported for Santacoder")
+            if not FLASH_ATTENTION:
+                raise NotImplementedError("sharded is not supported for Santacoder when FLASH_ATTENTION=0")
+            return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize)
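Note on the routing change above: any `bigcode` model id now resolves to `FlashSantacoderSharded` when sharding is requested and flash attention is available; without flash attention, sharded Santacoder still raises. A minimal, self-contained sketch of that decision (the helper below is illustrative only and is not the real `get_model`):

```python
# Hypothetical helper mirroring the dispatch above; class names are returned as
# strings so the sketch runs without the text_generation_server package.
def route_santacoder(model_id: str, sharded: bool, flash_attention: bool) -> str:
    if "bigcode" not in model_id:
        return "not a Santacoder model"
    if sharded:
        if not flash_attention:
            raise NotImplementedError(
                "sharded is not supported for Santacoder when FLASH_ATTENTION=0"
            )
        return "FlashSantacoderSharded"
    return "FlashSantacoder" if flash_attention else "SantaCoder"


print(route_santacoder("bigcode/santacoder", sharded=True, flash_attention=True))
# FlashSantacoderSharded
```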

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -1,6 +1,8 @@
 import torch
 import torch.distributed

+import torch.nn.functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
@@ -65,6 +67,127 @@ class FastLinear(nn.Linear):
             return torch.matmul(input, self.weight)


+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
 class FlashMQAttention(torch.nn.Module):
     def __init__(
         self,
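The three classes added above are the usual Megatron-style building blocks: `TensorParallelColumnLinear` shards its output features, `TensorParallelRowLinear` shards its input features and all-reduces the partial products, and `TensorParallelEmbedding` keeps one contiguous block of the vocabulary per rank. A self-contained sketch of the column/row composition (plain `nn.Linear` plus an explicit all-reduce rather than the classes from this diff, run as a single-process group so it is executable anywhere):

```python
import os
import torch
import torch.distributed as dist
from torch import nn

# One-process "world" so the sketch runs on a laptop; with N ranks each process
# would hold 1/N of the intermediate features.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.group.WORLD
world_size = dist.get_world_size(group)

hidden, intermediate = 16, 64

# Column parallel: each rank owns intermediate // world_size output features.
c_fc = nn.Linear(hidden, intermediate // world_size)
# Row parallel: each rank owns intermediate // world_size input features and
# produces a partial output that must be summed across ranks.
c_proj = nn.Linear(intermediate // world_size, hidden, bias=False)

x = torch.randn(2, hidden)
partial = c_proj(torch.relu(c_fc(x)))  # activation is applied per shard
dist.all_reduce(partial, group=group)  # sum of partials == unsharded result
print(partial.shape)                   # torch.Size([2, 16])

dist.destroy_process_group()
```

The row-parallel bias is the one piece that must not be summed, which is why the sharded loader later in this commit zeroes it on every rank except rank 0.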
@@ -80,10 +203,18 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
+            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.num_heads = self.num_heads // process_group.size()
+            self.hidden_size = self.hidden_size // process_group.size()
+            self.c_attn = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2),
+            )
+            self.c_proj = TensorParallelRowLinear(
+                hidden_size, hidden_size, process_group=process_group, reduce=True
+            )

     def forward(
         self,
@@ -94,10 +225,10 @@ class FlashMQAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        qkv = self.attn(hidden_states)
+        qkv = self.c_attn(hidden_states)

         # Split query from key_value
-        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split([self.head_size * self.num_heads, 2 * self.head_size], dim=1)

         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
@@ -179,7 +310,17 @@ class MLP(nn.Module):
             self.c_fc = FastLinear(hidden_size, intermediate_size)
             self.c_proj = FastLinear(intermediate_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.c_fc = TensorParallelColumnLinear(
+                hidden_size,
+                intermediate_size,
+                process_group=process_group,
+            )
+            self.c_proj = TensorParallelRowLinear(
+                intermediate_size,
+                hidden_size,
+                process_group=process_group,
+                reduce=False,
+            )

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
@@ -246,11 +387,28 @@ class FlashSantacoderModel(nn.Module):
         super().__init__()
         self.config = config

+        self.process_group = process_group
+        self.tp_embeddings = False
         if process_group is not None:
-            raise NotImplementedError
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if self.tp_embeddings:
+            self.wte = TensorParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                process_group=process_group,
+            )
+            self.wpe = TensorParallelEmbedding(
+                config.max_position_embeddings,
+                config.hidden_size,
+                process_group=process_group,
+            )
+        else:
+            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+            self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

         self.h = nn.ModuleList(
             [
@@ -273,9 +431,12 @@ class FlashSantacoderModel(nn.Module):
         self.num_heads = self.h[0].attn.num_heads

     def post_load_weights(self):
+        if self.tp_embeddings:
+            self.wte.add_null_idx()
+            self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.attn.transpose_weight()
+            layer.attn.c_attn.transpose_weight()
             layer.attn.c_proj.transpose_weight()
             layer.mlp.c_fc.transpose_weight()
             layer.mlp.c_proj.transpose_weight()
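`post_load_weights` now also pads the sharded embeddings with their masking row, and it still transposes every `FastLinear` weight once after loading so the forward pass can be a plain `matmul(input, weight)` instead of transposing on every call. A tiny sketch of that one-time transpose trade-off (standalone, not the TGI classes):

```python
import torch
from torch import nn

linear = nn.Linear(16, 32, bias=False)
x = torch.randn(4, 16)

ref = linear(x)  # nn.Linear computes x @ weight.T on every call

weight_t = nn.Parameter(linear.weight.data.T)  # transpose once, e.g. after loading
out = torch.matmul(x, weight_t)                # hot path is then a plain matmul

print(torch.allclose(ref, out, atol=1e-6))     # True
```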
@@ -289,6 +450,8 @@ class FlashSantacoderModel(nn.Module):
         past_key_values=None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if self.tp_embeddings:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)

         # Prefill
         if past_key_values is None:
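Each `TensorParallelEmbedding` only answers for its own block of ids and sends every other id to an extra all-zero row, so the per-rank lookups are summed with `all_reduce` to rebuild the full embedding. A small standalone sketch of that masking trick for one rank (toy sizes, independent of the classes in this diff):

```python
import torch
import torch.nn.functional as F

vocab, dim, world_size, rank = 8, 4, 2, 0
block = vocab // world_size
min_id, max_id = rank * block, (rank + 1) * block

# This rank's slice of the table plus one zero row used for masking
# (the same thing add_null_idx() does with F.pad).
weight = F.pad(torch.randn(block, dim), (0, 0, 0, 1))
null_idx = block

input_ids = torch.tensor([0, 3, 5, 7])  # ids 5 and 7 belong to the other rank
local_ids = torch.where(
    (input_ids < min_id) | (input_ids >= max_id), null_idx, input_ids - min_id
)
out = F.embedding(local_ids, weight)

print(local_ids)           # tensor([0, 3, 4, 4]): out-of-range ids hit the zero row
print(out[2].abs().sum())  # tensor(0.): contributes nothing to the all_reduce sum
```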
@@ -335,7 +498,14 @@ class FlashSantacoderForCausalLM(nn.Module):

         self.transformer = FlashSantacoderModel(config, process_group)

-        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

     def post_load_weights(self):
         self.transformer.post_load_weights()
@@ -352,4 +522,18 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.lm_head(hidden_states), present
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
+                world_logits, logits, group=self.transformer.process_group
+            )
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+
+        return logits, present
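With `tp_embeddings`, every rank's `lm_head` emits `vocab_size // world_size` logits, and the forward pass above all-gathers the shards and concatenates them along the vocabulary dimension (rank order matches vocabulary order because each rank owns a contiguous block). A runnable toy version of that gather, again with a single-process group:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.group.WORLD
world_size = dist.get_world_size(group)

tokens, local_vocab = 3, 5  # local_vocab == vocab_size // world_size on each rank
local_logits = torch.randn(tokens, local_vocab)

# Same pattern as the forward pass above.
world_logits = [torch.empty_like(local_logits) for _ in range(world_size)]
dist.all_gather(world_logits, local_logits, group=group)
full_logits = torch.cat(world_logits, dim=1)
print(full_logits.shape)  # (tokens, world_size * local_vocab)

dist.destroy_process_group()
```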

server/text_generation_server/models/flash_neox.py

@@ -5,7 +5,7 @@ from accelerate import init_empty_weights
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, List
+from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
@@ -63,7 +63,6 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
-            quantize=quantize,
             device=device,
             rank=self.rank,
             world_size=self.world_size,
@@ -80,16 +79,13 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
         device: torch.device,
         rank: int,
        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
-            ) as f:
+            with safe_open(file, framework="pt", device=str(device)) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)

server/text_generation_server/models/flash_santacoder.py

@@ -3,15 +3,20 @@ import torch.distributed

 from accelerate import init_empty_weights
 from opentelemetry import trace
+from safetensors import safe_open
 from pathlib import Path
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, GPT2Config
 from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
 )
 from text_generation_server.utils import (
+    initialize_torch_distributed,
     weight_files,
     download_weights,
     weight_hub_files,
@@ -36,10 +41,9 @@ class FlashSantacoder(FlashCausalLM):
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

-        config = AutoConfig.from_pretrained(
+        config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         # We do not use from_pretrained as we modified the model internal module layout
@@ -82,9 +86,9 @@ class FlashSantacoder(FlashCausalLM):
                 # Fused qkv
                 if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".attn.weight"
+                    final_key = layer_name + ".c_attn.weight"
                 elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".attn.bias"
+                    final_key = layer_name + ".c_attn.bias"
                 else:
                     final_key = key
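The key renaming above folds the checkpoint's separate `q_attn` and `kv_attn` tensors into the single fused `c_attn` parameter: query weights fill the first `head_size * num_heads` rows and the shared key/value weights the last `2 * head_size` rows (Santacoder uses multi-query attention, so there is exactly one key/value head). A toy sketch of that layout, matching the `.split([...])` in the attention forward earlier in this commit:

```python
import torch

hidden_size, num_heads = 12, 4
head_size = hidden_size // num_heads  # 3

# Separate projections as stored in the checkpoint (random toy values).
q_attn_weight = torch.randn(hidden_size, hidden_size)     # all query heads
kv_attn_weight = torch.randn(2 * head_size, hidden_size)  # one shared K and one shared V

# Fused c_attn layout: query rows first, then the two key/value rows.
c_attn_weight = torch.empty(head_size * (num_heads + 2), hidden_size)
c_attn_weight[: q_attn_weight.shape[0]] = q_attn_weight
c_attn_weight[head_size * num_heads :] = kv_attn_weight

qkv = torch.randn(2, hidden_size) @ c_attn_weight.T
query, key_value = qkv.split([head_size * num_heads, 2 * head_size], dim=1)
print(query.shape, key_value.shape)  # torch.Size([2, 12]) torch.Size([2, 6])
```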
@@ -103,13 +107,14 @@ class FlashSantacoder(FlashCausalLM):
                     or "c_proj.weight" in key
                     or "q_attn.weight" in key
                     or "kv_attn.weight" in key
+                    or "c_attn.weight" in key
                 ):
                     # Tranpose as we use nn.Linear instead of Conv1D
                     value = value.T

                 if current_parameter_tensor.device == torch.device("meta"):
                     # Init qkv
-                    if "attn.weight" in final_key:
+                    if "c_attn.weight" in final_key:
                         module._parameters[param_name] = value.new_empty(
                             (
                                 model.transformer.head_size
@@ -117,7 +122,7 @@ class FlashSantacoder(FlashCausalLM):
                                 value.shape[1],
                             )
                         )
-                    elif "attn.bias" in final_key:
+                    elif "c_attn.bias" in final_key:
                         module._parameters[param_name] = value.new_empty(
                             (
                                 model.transformer.head_size
@@ -157,3 +162,169 @@ class FlashSantacoder(FlashCausalLM):
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
+
+
+class FlashSantacoderSharded(FlashSantacoder):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            raise NotImplementedError("FlashSantacoder is only available on GPU")
+
+        if quantize:
+            raise NotImplementedError("FlashSantacoder does not support quantization")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left"
+        )
+
+        config = GPT2Config.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=True,  # Needed as the config is not part of Transformers
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashSantacoderForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            device=device,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        self.model = model.eval().to(dtype)
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        device: torch.device,
+        rank: int,
+        world_size: int,
+    ):
+        for file in filenames:
+            with safe_open(file, framework="pt", device=str(device)) as f:
+                for name in f.keys():
+                    slice_ = f.get_slice(name)
+
+                    layer_name = ".".join(name.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in name or "kv_attn.weight" in name:
+                        final_name = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in name or "kv_attn.bias" in name:
+                        final_name = layer_name + ".c_attn.bias"
+                    else:
+                        final_name = name
+
+                    module_name, param_name = final_name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
elif "c_attn" in name:
size = slice_.get_shape()[0]
block_size =
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
tensor = tensor.contiguous()
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_name:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
tensor.shape[1],
)
)
elif "c_attn.bias" in final_name:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "q_attn.bias" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in name:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in name:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights()
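The sharded loader relies on safetensors' lazy slicing: `get_slice(name)` returns a view whose `[rank * block_size : (rank + 1) * block_size]` block is read without materializing the full tensor, and that block is then copied into the (often meta-initialized) parameter. A minimal end-to-end sketch of that pattern on a throwaway file (the path and tensor name are made up for the example):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Toy checkpoint with one weight we treat as column-parallel (sharded along dim 0).
path = "/tmp/toy_shard.safetensors"
save_file({"c_fc.weight": torch.arange(32.0).reshape(8, 4)}, path)

rank, world_size = 1, 2
with safe_open(path, framework="pt", device="cpu") as f:
    slice_ = f.get_slice("c_fc.weight")
    size = slice_.get_shape()[0]
    block_size = size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    tensor = slice_[start:stop].contiguous()  # only this rank's rows are loaded

print(tensor.shape)  # torch.Size([4, 4]) -> rows 4..7 of the full 8x4 weight
```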