mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
formatting
This commit is contained in:
parent 1d0fa38cb8
commit 3149317fa1
@@ -33,7 +33,7 @@ except Exception as e:
 class GPTNeox(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""

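The docstring above says the class overrides forward to ignore position_ids. A minimal sketch of that pattern, assuming a standard Hugging Face GPT-NeoX-style causal LM that derives default position ids itself; the wrapper function name is illustrative, not the repository's code:

import torch
from typing import Optional, Tuple


def forward_ignoring_positions(
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple] = None,
):
    # position_ids is accepted for interface compatibility but never forwarded:
    # the underlying model fills in its own default positions.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    return outputs.logits, outputs.past_key_values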
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):

 class GPTNeoxSharded(GPTNeox):
     def __init__(
         self, model_name: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
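For context, the sharded constructor gets its process group, rank and world size from initialize_torch_distributed(). A minimal sketch of what such a helper typically does; this is an assumption about the usual torch.distributed setup (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT exported by the launcher), not the repository's actual implementation:

import os
import torch
import torch.distributed as dist


def init_distributed_sketch():
    # torchrun-style launchers export these variables for every worker process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return dist.group.WORLD, rank, world_size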
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):

     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
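The load_weights hunk iterates over safetensors files with safe_open and copies each tensor onto the matching parameter of its owning submodule. A simplified sketch of that loading loop, leaving out quantization and tensor-parallel slicing and assuming the tensor names line up with model.named_parameters():

import torch
from typing import List
from safetensors import safe_open


def load_safetensors_weights(model: torch.nn.Module, filenames: List[str], device: torch.device):
    parameters = dict(model.named_parameters())
    for filename in filenames:
        with safe_open(filename, framework="pt", device=str(device)) as f:
            for name in f.keys():
                tensor = f.get_tensor(name)
                current = parameters.get(name)
                if current is not None and current.shape != tensor.shape:
                    raise ValueError(
                        f"Name {name} -- Current {current.shape} and got {tensor.shape}"
                    )
                # Replace the parameter in place on the submodule that owns it.
                module_name, param_name = name.rsplit(".", 1)
                module = model.get_submodule(module_name)
                module._parameters[param_name] = torch.nn.Parameter(tensor, requires_grad=False)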
@@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
                     tensor = f.get_tensor(name)

                     if (
                         current_parameter_tensor is not None
                         and current_parameter_tensor.shape != tensor.shape
                     ):
                         raise ValueError(
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                         )

                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
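The Int8Params call in this hunk comes from bitsandbytes: wrapping a linear weight in bitsandbytes.nn.Int8Params defers int8 quantization until the tensor is moved to a CUDA device. A small sketch of that wrapping step; the keyword arguments are my assumption about typical usage, not necessarily the exact call in this file:

import torch
from bitsandbytes.nn import Int8Params

# A float16 weight as it would come out of a checkpoint shard.
weight = torch.randn(1024, 1024, dtype=torch.float16)

# Quantization is deferred: the data stays fp16 until the parameter reaches the GPU.
int8_weight = Int8Params(weight, requires_grad=False, has_fp16_weights=False)
int8_weight = int8_weight.to("cuda")  # int8 quantization happens on this transfer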
@@ -227,7 +227,7 @@ class GPTNeoxSharded(GPTNeox):
                         module._buffers[param_name] = tensor

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
         if self.model.gpt_neox.tp_embeddings:
             outputs = self.model.forward(
@@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):

             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)

             return logits, outputs.past_key_values
         # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
         else:
-            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
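The reformatted all_gather call is the step that reassembles tensor-parallel logits: each rank only holds a vocabulary shard, so the shards are gathered from every rank and concatenated along the last dimension. A self-contained sketch of that pattern; the function name and shapes are illustrative:

import torch
import torch.distributed as dist


def gather_sharded_logits(local_logits: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Each rank contributes a [batch, seq_len, vocab // world_size] slice.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=group)
    # The shards arrive ordered by rank, so concatenating on the vocab axis (dim=2)
    # recovers the full [batch, seq_len, vocab] logits on every rank.
    return torch.cat(shards, dim=2)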