feat(server): support sharded santacoder (#167)

Author: OlivierDehaene
Date:   2023-04-12 17:18:08 +02:00 (committed by GitHub)
parent 5fa8ae041c
commit 880a76eed5
11 changed files with 462 additions and 48 deletions

View File

@@ -18,8 +18,11 @@ from text_generation_server.models.t5 import T5Sharded

 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -49,6 +52,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(FlashLlamaSharded)
@@ -78,9 +82,13 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)

-    if "santacoder" in model_id:
+    if "bigcode" in model_id:
         if sharded:
-            raise NotImplementedError("sharded is not supported for Santacoder")
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize)
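
Note: the router now keys on "bigcode" (the organization prefix of the santacoder checkpoints) instead of "santacoder", and the sharded path additionally requires flash attention. A minimal, self-contained mirror of the new branch, not the repository's code; class names are returned as strings so the sketch runs without the server package:

    def route_bigcode(model_id: str, sharded: bool, flash_attention: bool) -> str:
        # mirrors the get_model() branch above for bigcode/santacoder checkpoints
        if "bigcode" not in model_id:
            raise ValueError("not a bigcode checkpoint")
        if sharded:
            if not flash_attention:
                raise NotImplementedError("Sharded Santacoder requires flash attention")
            return "FlashSantacoderSharded"
        return "FlashSantacoder" if flash_attention else "SantaCoder"

    assert route_bigcode("bigcode/santacoder", True, True) == "FlashSantacoderSharded"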

View File

@@ -93,10 +93,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -108,6 +109,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -157,7 +159,7 @@ class BLOOMSharded(BLOOM):
                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                if quantize:
                    if not HAS_BITS_AND_BYTES:
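
Note: across all sharded models touched by this commit, the dtype cast moves out of `self.model = model.eval().to(dtype)` and into `load_weights`, where each tensor is cast as it is copied (`tensor.contiguous().to(dtype)`), so the weights never have to sit in memory in the checkpoint dtype all at once. A minimal, single-process sketch of the idea (a plain dict stands in for the checkpoint files):

    import torch

    def load_casted(checkpoint: dict, dtype: torch.dtype) -> dict:
        # cast shard-by-shard at load time instead of calling model.to(dtype) afterwards
        return {name: t.contiguous().to(dtype) for name, t in checkpoint.items()}

    ckpt = {"wte.weight": torch.randn(8, 4)}  # fp32 on disk
    params = load_casted(ckpt, torch.float16)
    assert params["wte.weight"].dtype == torch.float16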

View File

@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                     x,
                     approximate="tanh"
                     if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                    else None,
+                    else "none",
                 )
             )

View File

@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )

         if process_group is None:
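
Note: `torch.nn.functional.gelu` takes its `approximate` argument as the string "none" or "tanh"; passing `None` is rejected instead of selecting the exact erf-based GELU, which is what the `else "none"` fix in the Llama, NeoX, and Santacoder MLPs addresses. A quick check:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4)
    exact = F.gelu(x, approximate="none")  # same as the default F.gelu(x)
    tanh = F.gelu(x, approximate="tanh")   # tanh approximation used by gelu_fast
    assert torch.allclose(exact, F.gelu(x))
    # F.gelu(x, approximate=None) raises, hence the "none" fallback above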

View File

@@ -1,6 +1,8 @@
 import torch
 import torch.distributed

+import torch.nn.functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
@@ -65,6 +67,127 @@ class FastLinear(nn.Linear):
             return torch.matmul(input, self.weight)


+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
 class FlashMQAttention(torch.nn.Module):
     def __init__(
         self,
@@ -80,10 +203,16 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
+            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.num_heads = self.num_heads // process_group.size()
+            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_proj = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(
         self,
@@ -94,10 +223,12 @@ class FlashMQAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        qkv = self.attn(hidden_states)
+        qkv = self.c_attn(hidden_states)

         # Split query from key_value
-        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )

         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
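
Note: Santacoder uses multi-query attention, so there is a single key/value head shared by all query heads. When `c_attn` is sharded, each rank keeps `num_heads // world_size` query heads but the full `2 * head_size` key/value slice, and the forward splits the fused projection accordingly. A single-process sketch of the split with hypothetical sizes:

    import torch

    head_size, num_heads, world_size = 64, 16, 2  # hypothetical dimensions
    heads_per_rank = num_heads // world_size

    # per-rank fused qkv: sharded query heads + the full, shared key/value slice
    qkv = torch.randn(3, head_size * (heads_per_rank + 2))
    query, key_value = qkv.split([head_size * heads_per_rank, 2 * head_size], dim=1)

    assert query.shape == (3, head_size * heads_per_rank)
    assert key_value.shape == (3, 2 * head_size)
    query = query.view(-1, heads_per_rank, head_size)  # ready for attention indexing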
@@ -171,7 +302,7 @@ class MLP(nn.Module):
                     x,
                     approximate="tanh"
                     if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                    else None,
+                    else "none",
                 )
             )
@@ -179,7 +310,16 @@ class MLP(nn.Module):
             self.c_fc = FastLinear(hidden_size, intermediate_size)
             self.c_proj = FastLinear(intermediate_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.c_fc = TensorParallelColumnLinear(
+                hidden_size,
+                intermediate_size,
+                process_group=process_group,
+            )
+            self.c_proj = TensorParallelRowLinear(
+                intermediate_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
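
Note: the `TensorParallelColumnLinear` / `TensorParallelRowLinear` pair added above follows the usual Megatron-style split: the column-parallel layer shards the output features so each rank produces a slice of the activation, and the row-parallel layer shards the input features and sums the partial products across ranks with `all_reduce`. Using column-parallel for `c_fc` and row-parallel for `c_proj` means the intermediate activation never has to be gathered. A single-process numerical sketch of the identity (a Python `sum` stands in for the `all_reduce`):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)
    w_fc, w_proj = torch.randn(32, 8), torch.randn(8, 32)  # hypothetical MLP weights
    world_size = 2

    # column-parallel c_fc: split the output rows of w_fc across ranks
    h_shards = [x @ w_fc.chunk(world_size, dim=0)[r].T for r in range(world_size)]
    # row-parallel c_proj: each rank consumes its slice of the intermediate activation
    partials = [h_shards[r] @ w_proj.chunk(world_size, dim=1)[r].T for r in range(world_size)]
    y_parallel = sum(partials)  # what all_reduce would produce

    y_reference = (x @ w_fc.T) @ w_proj.T
    assert torch.allclose(y_parallel, y_reference, atol=1e-4)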
@@ -246,11 +386,30 @@ class FlashSantacoderModel(nn.Module):
         super().__init__()
         self.config = config

+        self.process_group = process_group
+        self.tp_embeddings = False
         if process_group is not None:
-            raise NotImplementedError
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if self.tp_embeddings:
+            self.wte = TensorParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+            self.wpe = TensorParallelEmbedding(
+                config.max_position_embeddings,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+        else:
+            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+            self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

         self.h = nn.ModuleList(
             [
@@ -273,9 +432,12 @@ class FlashSantacoderModel(nn.Module):
         self.num_heads = self.h[0].attn.num_heads

     def post_load_weights(self):
+        if self.tp_embeddings:
+            self.wte.add_null_idx()
+            self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.attn.transpose_weight()
+            layer.attn.c_attn.transpose_weight()
             layer.attn.c_proj.transpose_weight()
             layer.mlp.c_fc.transpose_weight()
             layer.mlp.c_proj.transpose_weight()
@@ -289,6 +451,8 @@ class FlashSantacoderModel(nn.Module):
         past_key_values=None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if self.tp_embeddings:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)

         # Prefill
         if past_key_values is None:
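
Note: `TensorParallelEmbedding` gives each rank a `vocab_size // world_size` block of rows plus one extra zero row (`null_idx`); token ids outside the rank's `[min_id, max_id)` range are redirected to that zero row, so summing the per-rank lookups reconstructs the full table. Because `wte` and `wpe` are built with `reduce=False`, the model adds the two lookups first and issues a single `all_reduce` on the sum, as in the hunk above. A single-process simulation with two fake shards (the `sum` plays the role of the `all_reduce`):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab, dim, world_size = 8, 4, 2
    full = torch.randn(vocab, dim)  # reference embedding table
    block = vocab // world_size
    ids = torch.tensor([0, 3, 5, 7])

    def shard_lookup(rank: int) -> torch.Tensor:
        # this rank's rows plus an extra zero row used for masking (the null_idx trick)
        weight = F.pad(full[rank * block:(rank + 1) * block], (0, 0, 0, 1))
        null_idx = block
        local = torch.where((ids < rank * block) | (ids >= (rank + 1) * block),
                            null_idx, ids - rank * block)
        return weight[local]

    out = sum(shard_lookup(r) for r in range(world_size))  # what all_reduce would produce
    assert torch.allclose(out, full[ids])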
@@ -335,7 +499,14 @@ class FlashSantacoderForCausalLM(nn.Module):

         self.transformer = FlashSantacoderModel(config, process_group)

-        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

     def post_load_weights(self):
         self.transformer.post_load_weights()
@@ -352,4 +523,18 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.lm_head(hidden_states), present
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
+                world_logits, logits, group=self.transformer.process_group
+            )
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+
+        return logits, present
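
Note: with tensor-parallel embeddings each rank's `lm_head` only produces `vocab_size // world_size` logits, so the forward gathers every rank's slice with `all_gather` and concatenates along the vocab dimension. A single-process sketch in which a list of per-rank slices stands in for the gather:

    import torch

    torch.manual_seed(0)
    hidden, vocab, world_size = 4, 6, 2
    x = torch.randn(2, hidden)
    w = torch.randn(vocab, hidden)  # full lm_head weight

    shard_logits = [x @ w_r.T for w_r in w.chunk(world_size, dim=0)]  # per-rank outputs
    world_logits = torch.cat(shard_logits, dim=1)  # what all_gather + cat(dim=1) rebuilds

    assert torch.allclose(world_logits, x @ w.T)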

View File

@@ -5,7 +5,7 @@ from accelerate import init_empty_weights
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, List
+from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
@@ -63,13 +63,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
-            quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,16 +80,14 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
-            ) as f:
+            with safe_open(file, framework="pt", device=str(device)) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
@@ -142,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor

View File

@@ -3,15 +3,20 @@ import torch.distributed

 from accelerate import init_empty_weights
 from opentelemetry import trace
+from safetensors import safe_open
 from pathlib import Path
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, GPT2Config
 from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
 )
 from text_generation_server.utils import (
+    initialize_torch_distributed,
     weight_files,
     download_weights,
     weight_hub_files,
@@ -36,10 +41,9 @@ class FlashSantacoder(FlashCausalLM):
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

-        config = AutoConfig.from_pretrained(
+        config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         # We do not use from_pretrained as we modified the model internal module layout
@@ -54,12 +58,9 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)

         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
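
Note: the loader now reads the configuration with `GPT2Config` and passes `config.architectures[0].startswith("GPT2")` as a `transpose` flag. Checkpoints exported in the GPT-2 layout store their projection weights in transformers' `Conv1D` format, which is `(in_features, out_features)`, while the custom modules here are `nn.Linear` with `(out_features, in_features)` weights, so those tensors are transposed at load time. A small illustration of why the transpose preserves the computation:

    import torch

    in_features, out_features = 8, 32
    conv1d_weight = torch.randn(in_features, out_features)  # GPT-2 Conv1D layout: (in, out)
    x = torch.randn(2, in_features)

    linear_weight = conv1d_weight.T  # what the loader stores for nn.Linear
    assert torch.allclose(x @ conv1d_weight, x @ linear_weight.T)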
@@ -71,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -81,9 +83,9 @@ class FlashSantacoder(FlashCausalLM):

                 # Fused qkv
                 if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".attn.weight"
+                    final_key = layer_name + ".c_attn.weight"
                 elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".attn.bias"
+                    final_key = layer_name + ".c_attn.bias"
                 else:
                     final_key = key
@@ -97,18 +99,19 @@ class FlashSantacoder(FlashCausalLM):
                     current_parameter_tensor = None

                 if current_parameter_tensor is not None:
-                    if (
+                    if transpose and (
                         "c_fc.weight" in key
                         or "c_proj.weight" in key
                         or "q_attn.weight" in key
                         or "kv_attn.weight" in key
+                        or "c_attn.weight" in key
                     ):
                         # Tranpose as we use nn.Linear instead of Conv1D
                         value = value.T

                     if current_parameter_tensor.device == torch.device("meta"):
                         # Init qkv
-                        if "attn.weight" in final_key:
+                        if "c_attn.weight" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
@@ -116,7 +119,7 @@ class FlashSantacoder(FlashCausalLM):
                                     value.shape[1],
                                 )
                             )
-                        elif "attn.bias" in final_key:
+                        elif "c_attn.bias" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
@@ -156,3 +159,208 @@ class FlashSantacoder(FlashCausalLM):
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
+
+
+class FlashSantacoderSharded(FlashSantacoder):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
+
+        if quantize:
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left", truncation_side="left"
+        )
+
+        config = GPT2Config.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashSantacoderForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            device=device,
+            dtype=dtype,
+            rank=self.rank,
+            world_size=self.world_size,
+            transpose=config.architectures[0].startswith("GPT2"),
+        )
+        self.model = model.eval()
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+        transpose: bool,
+    ):
+        for file in filenames:
+            with safe_open(file, framework="pt", device=str(device)) as f:
+                for key in f.keys():
+                    slice_ = f.get_slice(key)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(key)
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Tranpose as we use nn.Linear instead of Conv1D
+                            tensor = tensor.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        tensor.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
+                        else:
+                            if current_parameter_tensor.shape != tensor.shape:
+                                raise ValueError(
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                )
+                            module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        torch.cuda.empty_cache()
+        model.post_load_weights()
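
Note: the sharded loader relies on safetensors slicing (`f.get_slice(key)`) so each rank only materializes its own block of every tensor: column-parallel weights are split along the output dimension, row-parallel weights along the input dimension, and embeddings and `lm_head.weight` along the vocab dimension. The row-parallel bias is kept only on rank 0 and zeroed on the other ranks so that the post-`all_reduce` output adds it exactly once. A small sketch of the slicing arithmetic (plain tensors stand in for the safetensors slices):

    import torch

    def rank_block(tensor: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
        """Return the contiguous block of `tensor` along `dim` owned by `rank`."""
        size = tensor.shape[dim]
        block = size // world_size
        start, stop = rank * block, (rank + 1) * block
        return tensor[start:stop] if dim == 0 else tensor[:, start:stop]

    w = torch.randn(32, 8)
    shards = [rank_block(w, r, world_size=2, dim=0) for r in range(2)]
    assert torch.equal(torch.cat(shards, dim=0), w)

    # row-parallel bias: only rank 0 keeps the real values so the all_reduce adds it once
    bias = torch.randn(8)
    per_rank_bias = [bias if r == 0 else torch.zeros_like(bias) for r in range(2)]
    assert torch.equal(sum(per_rank_bias), bias)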

View File

@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                if quantize:
                    if not HAS_BITS_AND_BYTES:

View File

@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                        f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                if quantize:
                    if not HAS_BITS_AND_BYTES:

View File

@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                if quantize:
                    if not HAS_BITS_AND_BYTES:

View File

@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                        f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                if quantize:
                    if not HAS_BITS_AND_BYTES: