mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
formatting
This commit is contained in:
parent
1d0fa38cb8
commit
3149317fa1
@ -33,7 +33,7 @@ except Exception as e:
|
||||
|
||||
class GPTNeox(CausalLM):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""Overwrite forward to ignore position_ids"""
|
||||
|
||||
@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
|
||||
|
||||
class GPTNeoxSharded(GPTNeox):
|
||||
def __init__(
|
||||
self, model_name: str, revision: Optional[str] = None, quantize: bool = False
|
||||
self, model_name: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
@ -227,7 +227,7 @@ class GPTNeoxSharded(GPTNeox):
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
if self.model.gpt_neox.tp_embeddings:
|
||||
outputs = self.model.forward(
|
||||
@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
||||
torch.distributed.all_gather(
|
||||
logits, outputs.logits, group=self.process_group
|
||||
)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
return logits, outputs.past_key_values
|
||||
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
|
||||
else:
|
||||
return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
|
||||
return super(GPTNeoxSharded, self).forward(
|
||||
input_ids, attention_mask, position_ids, past_key_values
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user