working model

OlivierDehaene 2023-04-11 20:00:12 +02:00
parent 9541c8f146
commit 622daeb0c8
5 changed files with 101 additions and 115 deletions

View File

@@ -83,9 +83,7 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError(
-                    "sharded is not supported for Santacoder when FLASH_ATTENTION=0"
-                )
+                raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
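For context, FLASH_ATT_ERROR_MESSAGE is a module-level format string raised whenever a sharded variant is requested but the Flash Attention kernels are not available. A minimal sketch of the gating pattern, with the message text assumed for illustration (the real constant may read differently):

    # Assumed wording; only the .format(...) call above comes from the diff.
    FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed."

    if sharded and not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder"))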

View File

@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )
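The else "none" fix matters because torch.nn.functional.gelu takes its approximate argument as a string ("none" or "tanh"), not as None; passing None fails at call time. A self-contained check using only stock PyTorch:

    import torch

    x = torch.randn(4)
    # Exact erf-based GELU, and the tanh approximation used by the
    # gelu_fast / gelu_pytorch_tanh activations; both expect a string.
    exact = torch.nn.functional.gelu(x, approximate="none")
    fast = torch.nn.functional.gelu(x, approximate="tanh")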

View File

@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )

         if process_group is None:

View File

@@ -209,7 +209,7 @@ class FlashMQAttention(torch.nn.Module):
             self.num_heads = self.num_heads // process_group.size()
             self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group, reduce=True
+                hidden_size, hidden_size, process_group=process_group,
             )

     def forward(
@@ -317,7 +317,6 @@ class MLP(nn.Module):
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
-                reduce=False,
             )

     def forward(self, hidden_states):
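With the reduce keyword dropped, c_proj and the MLP down projection rely on the default row-parallel behaviour: each rank multiplies its slice of the input features and the partial outputs are summed with an all-reduce. A simplified sketch of that forward pass (the constructor below is illustrative, not the repo's TensorParallelRowLinear):

    import torch
    import torch.distributed as dist
    from torch import nn

    class RowParallelLinear(nn.Module):
        # Illustrative only: each rank owns in_features // world_size input columns.
        def __init__(self, in_features, out_features, process_group):
            super().__init__()
            self.process_group = process_group
            self.linear = nn.Linear(in_features // process_group.size(), out_features)

        def forward(self, x):
            # x is already split along the feature dimension for this rank.
            out = self.linear(x)
            # Sum the partial results so every rank ends up with the full output.
            dist.all_reduce(out, group=self.process_group)
            return out

Because each rank's nn.Linear also holds a bias, the all-reduce would add the bias world_size times; the weight-loading code later in this commit works around that by zeroing the bias on every rank except 0 (the "Hack for Rowlinear to add the bias only once").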

View File

@@ -64,7 +64,7 @@ class FlashSantacoder(FlashCausalLM):
             dtype,
             config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -176,38 +176,37 @@ class FlashSantacoderSharded(FlashSantacoder):
             device = torch.device(f"cuda:{self.rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         else:
-            raise NotImplementedError("FlashSantacoder is only available on GPU")
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

         if quantize:
-            raise NotImplementedError("FlashSantacoder does not support quantization")
+            raise NotImplementedError("FlashSantacoderSharded does not support quantization")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )

         config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
+            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            # model = FlashSantacoderForCausalLM(config, self.process_group)
-            model = FlashSantacoderForCausalLM(config)
+            model = FlashSantacoderForCausalLM(config, self.process_group)

         torch.distributed.barrier(group=self.process_group)

         self.load_weights(
             model,
             filenames,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()

         torch.distributed.barrier(group=self.process_group)

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
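init_empty_weights (from accelerate) builds the model with every parameter on the meta device, so nothing is allocated until load_weights copies real tensors out of the safetensors shards; that is also why .to(dtype) moves off the model and a dtype=dtype argument is threaded into load_weights, which casts tensor by tensor. A small sketch of the idea, assuming accelerate is installed:

    import torch
    from accelerate import init_empty_weights

    with init_empty_weights():
        # Parameters are created on the meta device: shapes only, no storage.
        layer = torch.nn.Linear(4096, 4096)
    assert layer.weight.device == torch.device("meta")

    # Materialize a parameter from a real tensor (in the repo, read from a
    # safetensors slice), casting it to the target dtype on the way in.
    real = torch.randn(4096, 4096).to(torch.float16)
    layer._parameters["weight"] = real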
@@ -219,67 +218,68 @@ class FlashSantacoderSharded(FlashSantacoder):
         model,
         filenames: List[str],
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
         transpose: bool,
     ):
         for file in filenames:
             with safe_open(file, framework="pt", device=str(device)) as f:
-                for name in f.keys():
-                    slice_ = f.get_slice(name)
-                    layer_name = ".".join(name.split(".")[:4])
+                for key in f.keys():
+                    slice_ = f.get_slice(key)
+                    layer_name = ".".join(key.split(".")[:4])

                     # Fused qkv
-                    if "q_attn.weight" in name or "kv_attn.weight" in name:
-                        final_name = layer_name + ".c_attn.weight"
-                    elif "q_attn.bias" in name or "kv_attn.bias" in name:
-                        final_name = layer_name + ".c_attn.bias"
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
                     else:
-                        final_name = name
+                        final_key = key

-                    module_name, param_name = final_name.rsplit(".", 1)
+                    module_name, param_name = final_key.rsplit(".", 1)
                     module = model.get_submodule(module_name)

-                    # if isinstance(module, TensorParallelColumnLinear):
-                    #     dim = 1 if transpose and "weight" in param_name else 0
-                    #     size = slice_.get_shape()[dim]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
-                    # elif isinstance(module, TensorParallelRowLinear):
-                    #     if param_name == "weight":
-                    #         dim = 0 if transpose else 1
-                    #         size = slice_.get_shape()[dim]
-                    #         block_size = size // world_size
-                    #         start = rank * block_size
-                    #         stop = (rank + 1) * block_size
-                    #         tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
-                    #     else:
-                    #         tensor = slice_[:]
-                    #         # XXX: Hack for Rowlinear to add the bias only once.
-                    #         if rank != 0:
-                    #             tensor = torch.zeros_like(tensor)
-                    # elif isinstance(module, TensorParallelEmbedding):
-                    #     size = slice_.get_shape()[0]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop]
-                    # elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-                    #     size = slice_.get_shape()[0]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop]
-                    # else:
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
                         try:
                             tensor = slice_[:]
                         except:
-                            tensor = f.get_tensor(name)
-                    tensor = tensor.contiguous()
+                            tensor = f.get_tensor(key)
+                    tensor = tensor.contiguous().to(dtype)

                     try:
                         current_parameter_tensor = module._parameters[param_name]
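The newly enabled block is the heart of the sharding: column-parallel weights are split along their output dimension, row-parallel weights along their input dimension, and embeddings plus the tensor-parallel lm_head along dimension 0. The transpose flag flips the split for weights because GPT2 checkpoints store Conv1D weights as (in_features, out_features) while nn.Linear expects (out_features, in_features). The slice arithmetic, worked through with hypothetical sizes:

    # Hypothetical shapes, for illustration only.
    world_size = 2
    rank = 1
    out_features, in_features = 6144, 2048

    # Column-parallel weight in nn.Linear layout (out, in): shard dim 0.
    block_size = out_features // world_size            # 3072 rows per rank
    start, stop = rank * block_size, (rank + 1) * block_size
    # rank 1 loads weight[3072:6144, :]

    # The same layer stored in GPT2 Conv1D layout (in, out): transpose=True,
    # so the output dimension becomes dim 1 and rank 1 loads weight[:, 3072:6144].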
@@ -288,18 +288,18 @@ class FlashSantacoderSharded(FlashSantacoder):
                     if current_parameter_tensor is not None:
                         if transpose and (
-                            "c_fc.weight" in name
-                            or "c_proj.weight" in name
-                            or "q_attn.weight" in name
-                            or "kv_attn.weight" in name
-                            or "c_attn.weight" in name
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
                         ):
                             # Tranpose as we use nn.Linear instead of Conv1D
                             tensor = tensor.T

                         if current_parameter_tensor.device == torch.device("meta"):
                             # Init qkv
-                            if "c_attn.weight" in final_name:
+                            if "c_attn.weight" in final_key:
                                 module._parameters[param_name] = tensor.new_empty(
                                     (
                                         model.transformer.head_size
@@ -307,7 +307,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                                         tensor.shape[1],
                                     )
                                 )
-                            elif "c_attn.bias" in final_name:
+                            elif "c_attn.bias" in final_key:
                                 module._parameters[param_name] = tensor.new_empty(
                                     (
                                         model.transformer.head_size
@@ -316,63 +316,47 @@ class FlashSantacoderSharded(FlashSantacoder):
                                 )

                         # Copy to correct slice
-                        # if "q_attn" in name:
-                        #     size = tensor.shape[0]
-                        #     block_size = size // world_size
-                        #     start = rank * block_size
-                        #     stop = (rank + 1) * block_size
-                        #     tensor = tensor[start:stop]
-                        #     module._parameters[param_name][: tensor.shape[0]] = tensor
-                        # elif "kv_attn.weight" in name:
-                        #     module._parameters[param_name][
-                        #         model.transformer.head_size
-                        #         * model.transformer.num_heads :
-                        #     ] = tensor
-                        # elif "kv_attn.bias" in name:
-                        #     module._parameters[param_name][
-                        #         model.transformer.head_size
-                        #         * model.transformer.num_heads :
-                        #     ] = tensor
-                        # elif "c_attn" in name:
-                        #     q_tensor = tensor[: -2 * model.transformer.head_size]
-                        #     kv_tensor = tensor[-2 * model.transformer.head_size :]
-                        #     from loguru import logger
-                        #
-                        #     block_size = q_tensor.shape[0] // world_size
-                        #     start = rank * block_size
-                        #     stop = (rank + 1) * block_size
-                        #     q_tensor = q_tensor[start:stop]
-                        #     logger.error(q_tensor.shape)
-                        #     logger.error(kv_tensor.shape)
-                        #     module._parameters[param_name][
-                        #         : q_tensor.shape[0]
-                        #     ] = q_tensor
-                        #     module._parameters[param_name][
-                        #         q_tensor.shape[0] :
-                        #     ] = kv_tensor
-                        from loguru import logger
-                        if "q_attn.weight" in name:
-                            logger.error(f"q - {module._parameters[param_name][: tensor.shape[0]].shape} - {tensor.shape}")
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
                             module._parameters[param_name][: tensor.shape[0]] = tensor
-                        elif "q_attn.bias" in name:
-                            module._parameters[param_name][: tensor.shape[0]] = tensor
-                        elif "kv_attn.weight" in name:
-                            logger.error(f"kv - {module._parameters[param_name][model.transformer.head_size * model.transformer.num_heads:].shape} - {tensor.shape}")
+                        elif "kv_attn.weight" in key:
                             module._parameters[param_name][
-                                model.transformer.head_size * model.transformer.num_heads:
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
                             ] = tensor
-                        elif "kv_attn.bias" in name:
+                        elif "kv_attn.bias" in key:
                             module._parameters[param_name][
-                                model.transformer.head_size * model.transformer.num_heads:
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
                             ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
                         else:
                             if current_parameter_tensor.shape != tensor.shape:
                                 raise ValueError(
-                                    f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                 )
                             module._parameters[param_name] = tensor
                     else:
                         module._buffers[param_name] = tensor

         torch.cuda.empty_cache()
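Santacoder uses multi-query attention: the fused c_attn weight stacks head_size * num_heads query rows on top of 2 * head_size key/value rows. The new copy logic therefore shards only the query block across ranks and replicates the key/value block on every rank, which matches the per-rank FastLinear sized as head_size * (num_heads_per_rank + 2) in the attention module above. A small end-to-end sketch with hypothetical dimensions:

    import torch

    # Hypothetical multi-query attention dimensions.
    head_size, num_heads, world_size, rank = 128, 16, 2, 0
    hidden = head_size * num_heads

    c_attn = torch.randn(hidden + 2 * head_size, hidden)   # fused q + kv weight

    # Queries: shard rows across ranks.
    q = c_attn[: -2 * head_size]
    block = q.shape[0] // world_size
    q_shard = q[rank * block : (rank + 1) * block]

    # Key/value: a single shared head, copied onto every rank.
    kv = c_attn[-2 * head_size :]

    # Per-rank destination buffer, roughly what tensor.new_empty allocates above.
    num_heads_per_rank = num_heads // world_size
    dest = torch.empty(head_size * (num_heads_per_rank + 2), hidden)
    dest[: q_shard.shape[0]] = q_shard
    dest[q_shard.shape[0] :] = kv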