diff --git a/README.md b/README.md
index 67940ae8..d54a092a 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
 
diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index d901cae3..a8f7f365 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight":
+                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                             )
 
                         if (
-                            type(module)
-                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                            and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
@@ -229,16 +229,24 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
 
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
+            logits = torch.cat(logits, dim=2)
 
-        return logits, outputs.past_key_values
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not be, as they might not be divisible by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 62d60635..3b07ef3f 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
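
For reference, the gathering step added to `forward` follows the standard tensor-parallel pattern: when the embeddings are sharded, each rank produces logits for only its slice of the vocabulary, and `torch.distributed.all_gather` reassembles the full distribution. A minimal standalone sketch of that pattern, not part of the patch (the function name and shape comments here are illustrative assumptions):

import torch
import torch.distributed as dist

def gather_sharded_logits(
    local_logits: torch.Tensor,  # illustrative shape: [batch, seq, vocab // world_size]
    world_size: int,
    process_group,
) -> torch.Tensor:
    # Allocate one receive buffer per rank, then collect every rank's shard.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=process_group)
    # Concatenate along the vocabulary dimension (dim=2) to rebuild the full
    # [batch, seq, vocab] logits, mirroring what the patched forward() does.
    return torch.cat(shards, dim=2)

The else branch in the patch exists because this gather is only valid when the vocabulary divides evenly across shards; otherwise the model falls back to the unsharded embedding path of the parent class.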