formatting

OlivierDehaene 2023-02-01 11:48:18 +01:00
parent 1d0fa38cb8
commit 3149317fa1


@@ -33,7 +33,7 @@ except Exception as e:
class GPTNeox(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
class GPTNeoxSharded(GPTNeox):
def __init__(
self, model_name: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
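`initialize_torch_distributed` is a repo helper whose body is not part of this diff; purely as an illustration of what such a helper usually wraps (an assumption, not the project's implementation), an env-var based initialization might look like:

```python
import os
import torch
import torch.distributed as dist

def initialize_torch_distributed_sketch():
    # Hypothetical stand-in: read the rank/world size exported by the launcher
    # and join the default process group (NCCL on GPU, gloo otherwise).
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return dist.group.WORLD, rank, world_size
```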
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
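The `safe_open` loop above is the standard safetensors lazy-loading pattern; a self-contained sketch of the same idea (hypothetical filename, assuming the `safetensors` package) is:

```python
from safetensors import safe_open

def load_safetensors_state_dict(filename: str, device: str = "cpu") -> dict:
    # Lazily open one safetensors shard and materialize each stored tensor,
    # mirroring the safe_open / f.keys() / get_tensor loop in the hunk above.
    state_dict = {}
    with safe_open(filename, framework="pt", device=device) as f:
        for name in f.keys():
            state_dict[name] = f.get_tensor(name)
    return state_dict

# Usage (hypothetical path):
# model.load_state_dict(load_safetensors_state_dict("model-00001-of-00002.safetensors"), strict=False)
```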
@@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
@@ -227,7 +227,7 @@ class GPTNeoxSharded(GPTNeox):
module._buffers[param_name] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
if self.model.gpt_neox.tp_embeddings:
outputs = self.model.forward(
@@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(
logits, outputs.logits, group=self.process_group
)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
# While the model itself is sharded, the embeddings might not be, as they might not be divisible by num-shard
else:
return super(GPTNeoxSharded, self).forward(
input_ids, attention_mask, position_ids, past_key_values
)
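For context on the call being re-wrapped above: each shard holds logits for only a slice of the vocabulary, so `all_gather` collects every shard's slice and `torch.cat` along the last dimension rebuilds the full logits. A minimal sketch of that pattern (assuming `torch.distributed` is already initialized; the function name is illustrative, not from the repo):

```python
import torch
import torch.distributed as dist

def gather_sharded_logits(local_logits: torch.Tensor, process_group=None) -> torch.Tensor:
    # local_logits: [batch, seq_len, vocab_size // world_size] on each rank.
    world_size = dist.get_world_size(group=process_group)
    logits = [torch.empty_like(local_logits) for _ in range(world_size)]
    # Every rank receives every other rank's vocabulary slice.
    dist.all_gather(logits, local_logits, group=process_group)
    # Concatenate along the vocabulary dimension to recover the full distribution.
    return torch.cat(logits, dim=2)
```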