diff --git a/README.md b/README.md
index 67940ae8..d54a092a 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
 
diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index d901cae3..200eec6d 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -33,7 +33,7 @@ except Exception as e:
 
 class GPTNeox(CausalLM):
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
 
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
 
 class GPTNeoxSharded(GPTNeox):
     def __init__(
-            self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):
 
     @staticmethod
     def load_weights(
-            model,
-            filenames: List[str],
-            quantize: bool,
-            device: torch.device,
-            rank: int,
-            world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                    file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight":
+                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
@@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
                         tensor = f.get_tensor(name)
 
                     if (
-                            current_parameter_tensor is not None
-                            and current_parameter_tensor.shape != tensor.shape
+                        current_parameter_tensor is not None
+                        and current_parameter_tensor.shape != tensor.shape
                     ):
                         raise ValueError(
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
@@ -227,18 +227,22 @@ class GPTNeoxSharded(GPTNeox):
                 module._buffers[param_name] = tensor
 
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
 
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)
 
-        return logits, outputs.past_key_values
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 62d60635..3b07ef3f 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
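
The heart of the `GPTNeoxSharded.forward` change is the logit gather taken on the new `tp_embeddings` branch: each rank computes logits for its own slice of the vocabulary, `torch.distributed.all_gather` collects every slice, and `torch.cat` reassembles the full vocabulary axis. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the tensor shapes, the single-process `gloo` group, and the variable names are illustrative assumptions made so the snippet runs on its own.

```python
# Sketch of the vocab-sharded logit gather performed in the tp_embeddings branch.
import os

import torch
import torch.distributed as dist

# Single-process "gloo" group so the sketch runs without torchrun; the server
# instead builds its process group via initialize_torch_distributed(), one rank per shard.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
world_size = dist.get_world_size()

# Hypothetical shapes: [batch, sequence, vocab_shard] logits owned by this rank.
local_logits = torch.randn(2, 3, 8)

# Gather every rank's shard, then concatenate along the vocabulary axis (dim=2).
gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
dist.all_gather(gathered, local_logits)
full_logits = torch.cat(gathered, dim=2)

assert full_logits.shape == (2, 3, 8 * world_size)
dist.destroy_process_group()
```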
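Because `all_gather` expects every rank to contribute a tensor of the same shape, the sharded path only makes sense when the embedding/vocabulary dimension splits evenly across shards, which is what the new `else` fallback to the plain `CausalLM` forward covers. The arithmetic below is a small illustration with made-up numbers, not values taken from the PR.

```python
# Why the else branch exists: a plain `size // world_size` block per rank
# silently drops rows whenever the size is not divisible by the shard count.
vocab_size, world_size = 50257, 3          # hypothetical vocab and shard count
block_size = vocab_size // world_size      # 16752 rows assigned to each rank
covered = block_size * world_size          # 50256 rows covered in total
print(vocab_size - covered)                # 1 row would be left out of every shard
```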