working model

This commit is contained in:
OlivierDehaene 2023-04-11 20:00:12 +02:00
parent 9541c8f146
commit 622daeb0c8
5 changed files with 101 additions and 115 deletions

View File

@@ -83,9 +83,7 @@ def get_model(
if "bigcode" in model_id:
if sharded:
if not FLASH_ATTENTION:
- raise NotImplementedError(
- "sharded is not supported for Santacoder when FLASH_ATTENTION=0"
- )
+ raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
return FlashSantacoderSharded(model_id, revision=revision)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
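
For context, get_model gates the sharded Santacoder path on a module-level FLASH_ATTENTION flag and formats a shared error template. A minimal sketch of that pattern, with an assumed template string and an assumed import-time detection (neither definition is shown in this diff):

    # Sketch only: stand-ins for the real definitions in text_generation_server.
    FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed"  # assumed wording

    try:
        import flash_attn  # noqa: F401
        FLASH_ATTENTION = True
    except ImportError:
        FLASH_ATTENTION = False

    def require_flash_attention(model_name: str) -> None:
        # Mirrors the NotImplementedError raised for "Sharded Santacoder" above.
        if not FLASH_ATTENTION:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_name))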

View File

@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
- else None,
+ else "none",
)
)
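
For reference, torch.nn.functional.gelu takes approximate as a string, either "none" (exact, erf-based) or "tanh"; passing Python None fails at call time, which is what the else "none" change above fixes. A small self-contained illustration of the same selection logic:

    import torch
    import torch.nn.functional as F

    def make_gelu(act: str):
        # gelu_fast / gelu_pytorch_tanh map to the tanh approximation,
        # every other act string falls back to the exact kernel via approximate="none".
        approximate = "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
        return lambda x: F.gelu(x, approximate=approximate)

    x = torch.randn(4)
    assert torch.allclose(make_gelu("gelu")(x), F.gelu(x))  # exact GELU path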

View File

@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
self.act = (
ACT2FN[act]
if "gelu" not in act
- else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate="tanh"
+ if act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none",
+ )
)
if process_group is None:

View File

@@ -209,7 +209,7 @@ class FlashMQAttention(torch.nn.Module):
self.num_heads = self.num_heads // process_group.size()
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
self.c_proj = TensorParallelRowLinear(
- hidden_size, hidden_size, process_group=process_group, reduce=True
+ hidden_size, hidden_size, process_group=process_group,
)
def forward(
@@ -317,7 +317,6 @@ class MLP(nn.Module):
intermediate_size,
hidden_size,
process_group=process_group,
- reduce=False,
)
def forward(self, hidden_states):
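
The reduce keyword dropped in both hunks above controls whether TensorParallelRowLinear all-reduces its partial output; after this change the call sites rely on the layer's default. The class itself is not shown in this diff, so the following is only a hedged, Megatron-style sketch of what a row-parallel linear with such a flag typically does (class and argument names are illustrative):

    import torch
    import torch.distributed as dist

    class RowParallelLinear(torch.nn.Module):
        # The weight is split along its input dimension across ranks; each rank
        # multiplies its shard and the partial outputs are summed with all_reduce.
        def __init__(self, in_features, out_features, world_size, reduce=True):
            super().__init__()
            assert in_features % world_size == 0
            self.linear = torch.nn.Linear(in_features // world_size, out_features, bias=False)
            self.reduce = reduce

        def forward(self, x):
            y = self.linear(x)  # x already holds this rank's slice of the activations
            if self.reduce:
                dist.all_reduce(y)  # sum the partial results from every rank
            return y

Skipping the reduction only makes sense when a later operation performs the summation itself; otherwise each rank is left with an incomplete partial product.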

View File

@@ -64,7 +64,7 @@ class FlashSantacoder(FlashCausalLM):
dtype,
config.architectures[0].startswith("GPT2")
)
- self.model = model.eval().to(device).to(dtype)
+ self.model = model.eval()
super(FlashCausalLM, self).__init__(
tokenizer=tokenizer,
@@ -176,38 +176,37 @@ class FlashSantacoderSharded(FlashSantacoder):
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
if quantize:
raise NotImplementedError("FlashSantacoder does not support quantization")
raise NotImplementedError("FlashSantacoderSharded does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
- model_id, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left", truncation_side="left"
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True, # Needed as the config is not part of Transformers
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
- # model = FlashSantacoderForCausalLM(config, self.process_group)
- model = FlashSantacoderForCausalLM(config)
+ model = FlashSantacoderForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
- self.model = model.eval().to(dtype)
+ self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
tokenizer=tokenizer,
@@ -219,67 +218,68 @@ class FlashSantacoderSharded(FlashSantacoder):
model,
filenames: List[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
transpose: bool,
):
for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f:
- for name in f.keys():
- slice_ = f.get_slice(name)
+ for key in f.keys():
+ slice_ = f.get_slice(key)
- layer_name = ".".join(name.split(".")[:4])
+ layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in name or "kv_attn.weight" in name:
final_name = layer_name + ".c_attn.weight"
elif "q_attn.bias" in name or "kv_attn.bias" in name:
final_name = layer_name + ".c_attn.bias"
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
- final_name = name
+ final_key = key
- module_name, param_name = final_name.rsplit(".", 1)
+ module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
- # if isinstance(module, TensorParallelColumnLinear):
- # dim = 1 if transpose and "weight" in param_name else 0
- # size = slice_.get_shape()[dim]
- # block_size = size // world_size
- # start = rank * block_size
- # stop = (rank + 1) * block_size
- # tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
- # elif isinstance(module, TensorParallelRowLinear):
- # if param_name == "weight":
- # dim = 0 if transpose else 1
- # size = slice_.get_shape()[dim]
- # block_size = size // world_size
- # start = rank * block_size
- # stop = (rank + 1) * block_size
- # tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
- # else:
- # tensor = slice_[:]
- # # XXX: Hack for Rowlinear to add the bias only once.
- # if rank != 0:
- # tensor = torch.zeros_like(tensor)
- # elif isinstance(module, TensorParallelEmbedding):
- # size = slice_.get_shape()[0]
- # block_size = size // world_size
- # start = rank * block_size
- # stop = (rank + 1) * block_size
- # tensor = slice_[start:stop]
- # elif name == "lm_head.weight" and model.transformer.tp_embeddings:
- # size = slice_.get_shape()[0]
- # block_size = size // world_size
- # start = rank * block_size
- # stop = (rank + 1) * block_size
- # tensor = slice_[start:stop]
- # else:
+ if isinstance(module, TensorParallelColumnLinear):
+ dim = 1 if transpose and "weight" in param_name else 0
+ size = slice_.get_shape()[dim]
+ block_size = size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+ elif isinstance(module, TensorParallelRowLinear):
+ if param_name == "weight":
+ dim = 0 if transpose else 1
+ size = slice_.get_shape()[dim]
+ block_size = size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+ else:
+ tensor = slice_[:]
+ # XXX: Hack for Rowlinear to add the bias only once.
+ if rank != 0:
+ tensor = torch.zeros_like(tensor)
+ elif isinstance(module, TensorParallelEmbedding):
+ size = slice_.get_shape()[0]
+ block_size = size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensor = slice_[start:stop]
+ elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+ size = slice_.get_shape()[0]
+ block_size = size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
- tensor = f.get_tensor(name)
+ tensor = f.get_tensor(key)
- tensor = tensor.contiguous()
+ tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
@@ -288,18 +288,18 @@ class FlashSantacoderSharded(FlashSantacoder):
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in name
or "c_proj.weight" in name
or "q_attn.weight" in name
or "kv_attn.weight" in name
or "c_attn.weight" in name
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Transpose as we use nn.Linear instead of Conv1D
tensor = tensor.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_name:
if "c_attn.weight" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
@@ -307,7 +307,7 @@ class FlashSantacoderSharded(FlashSantacoder):
tensor.shape[1],
)
)
elif "c_attn.bias" in final_name:
elif "c_attn.bias" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
@@ -316,63 +316,47 @@ class FlashSantacoderSharded(FlashSantacoder):
)
# Copy to correct slice
# if "q_attn" in name:
# size = tensor.shape[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = tensor[start:stop]
# module._parameters[param_name][: tensor.shape[0]] = tensor
# elif "kv_attn.weight" in name:
# module._parameters[param_name][
# model.transformer.head_size
# * model.transformer.num_heads :
# ] = tensor
# elif "kv_attn.bias" in name:
# module._parameters[param_name][
# model.transformer.head_size
# * model.transformer.num_heads :
# ] = tensor
# elif "c_attn" in name:
# q_tensor = tensor[: -2 * model.transformer.head_size]
# kv_tensor = tensor[-2 * model.transformer.head_size :]
# from loguru import logger
#
# block_size = q_tensor.shape[0] // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# q_tensor = q_tensor[start:stop]
# logger.error(q_tensor.shape)
# logger.error(kv_tensor.shape)
# module._parameters[param_name][
# : q_tensor.shape[0]
# ] = q_tensor
# module._parameters[param_name][
# q_tensor.shape[0] :
# ] = kv_tensor
from loguru import logger
if "q_attn.weight" in name:
logger.error(f"q - {module._parameters[param_name][: tensor.shape[0]].shape} - {tensor.shape}")
if "q_attn" in key:
size = tensor.shape[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = tensor[start:stop]
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "q_attn.bias" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in name:
logger.error(f"kv - {module._parameters[param_name][model.transformer.head_size * model.transformer.num_heads:].shape} - {tensor.shape}")
elif "kv_attn.weight" in key:
module._parameters[param_name][
- model.transformer.head_size * model.transformer.num_heads:
+ model.transformer.head_size
+ * model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in name:
elif "kv_attn.bias" in key:
module._parameters[param_name][
- model.transformer.head_size * model.transformer.num_heads:
+ model.transformer.head_size
+ * model.transformer.num_heads :
] = tensor
elif "c_attn" in key:
# Slice q_tensor by shard
q_tensor = tensor[: -2 * model.transformer.head_size]
block_size = q_tensor.shape[0] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = q_tensor[start:stop]
module._parameters[param_name][
: q_tensor.shape[0]
] = q_tensor
# Kv tensor is copied for every shard
kv_tensor = tensor[-2 * model.transformer.head_size :]
module._parameters[param_name][
q_tensor.shape[0] :
] = kv_tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
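
The q_attn/kv_attn/c_attn branches at the end of load_weights shard the fused multi-query projection: the query rows are split across ranks while the single shared key/value head (2 * head_size rows) is copied to every rank, matching the FastLinear(hidden_size, self.head_size * (self.num_heads + 2)) layout earlier in this commit. A condensed sketch of that slicing, as a standalone helper that is illustrative rather than repository code:

    import torch

    def shard_fused_qkv(c_attn_weight, head_size, rank, world_size):
        # Assumed layout along dim 0: [q_0 ... q_{n-1}, k, v], as the slicing above implies.
        q = c_attn_weight[: -2 * head_size]   # all query heads
        kv = c_attn_weight[-2 * head_size :]  # the shared key and value head
        block_size = q.shape[0] // world_size
        start, stop = rank * block_size, (rank + 1) * block_size
        # Each rank keeps its own block of query rows plus a full copy of k and v.
        return torch.cat([q[start:stop], kv], dim=0)

    # e.g. head_size=4, 8 query heads, 2 ranks: each shard holds 4 query heads plus kv
    w = torch.randn(8 * 4 + 2 * 4, 16)
    print(shard_fused_qkv(w, head_size=4, rank=0, world_size=2).shape)  # torch.Size([24, 16])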