From 3149317fa1beeb65bdd1460f2f17c2cb6160aee1 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 1 Feb 2023 11:48:18 +0100
Subject: [PATCH] formatting

---
 server/text_generation/models/gpt_neox.py | 38 +++++++++++++----------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index 200eec6d..a8f7f365 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -33,7 +33,7 @@ except Exception as e:
 
 class GPTNeox(CausalLM):
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
 
@@ -49,7 +49,7 @@
 
 class GPTNeoxSharded(GPTNeox):
     def __init__(
-            self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):
 
     @staticmethod
     def load_weights(
-            model,
-            filenames: List[str],
-            quantize: bool,
-            device: torch.device,
-            rank: int,
-            world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                    file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -158,8 +158,8 @@
                     tensor = f.get_tensor(name)
                     if (
-                            current_parameter_tensor is not None
-                            and current_parameter_tensor.shape != tensor.shape
+                        current_parameter_tensor is not None
+                        and current_parameter_tensor.shape != tensor.shape
                     ):
                         raise ValueError(
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
@@ -176,9 +176,9 @@
                             )
 
                         if (
-                                type(module)
-                                in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                                and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
@@ -227,7 +227,7 @@
                         module._buffers[param_name] = tensor
 
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        if self.model.gpt_neox.tp_embeddings:
            outputs = self.model.forward(
@@ -239,10 +239,14 @@
 
             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)
 
             return logits, outputs.past_key_values
         # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
         else:
-            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
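Aside (not part of the patch): the last two hunks only re-wrap the call that gathers sharded logits; the behavior is unchanged. Below is a minimal runnable sketch of that gather-then-concatenate pattern, assuming a single-process "gloo" group in place of a real one-process-per-shard launch; the tensor shapes, the init address, and the `sharded_logits` name are invented for illustration.

# Sketch of the logit gather in GPTNeoxSharded.forward above.
# Assumption: world_size=1 "gloo" group as a stand-in for a real
# multi-shard setup where each rank runs one model shard.
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
world_size = dist.get_world_size()

# Each shard holds logits for its slice of the vocabulary:
# [batch, seq_len, vocab_size // world_size] (shapes made up for the demo).
sharded_logits = torch.randn(2, 5, 128)

# Gather every shard's slice, then stitch the vocabulary dimension (dim=2)
# back together, mirroring the all_gather + torch.cat in the patch.
logits = [torch.empty_like(sharded_logits) for _ in range(world_size)]
dist.all_gather(logits, sharded_logits, group=dist.group.WORLD)
full_logits = torch.cat(logits, dim=2)

print(full_logits.shape)  # torch.Size([2, 5, 128]) when world_size == 1
dist.destroy_process_group()

With several real shards, each rank would hold a distinct vocabulary slice, and the concatenation along dim=2 reassembles the full logits that sampling then consumes.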