Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:

- Remove the need for `no_init_weights`
- Remove all the weird `bnb_linear`, `load_weights` and `post_load_weights` code.

New code layout:

- A new class `Weights` in charge of loading the weights from multiple files into the appropriate tensors (potentially sharded).
- TP layers are now "shells": they contain the code that knows what kind of sharding we need plus the eventual `all_reduce`. They do not inherit from Linear but contain some kind of Linear instead; the contained linear can be either FastLinear, BnbLinear or GPTQ Linear next (see the sketch below).
- All modeling code is explicitly written for sharding; the process group is just a no-op for non-sharded code (removes a lot of test cases).

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
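A minimal sketch of the "shell" idea described above, not the actual TGI implementation: the class name `TensorParallelRowLinearSketch` is illustrative, and the only parts taken from the description are that the shell wraps a contained Linear and calls `all_reduce` on whatever process group it was given. With the `FakeGroup` defined in the file below, `size()` is 1, so the collective is skipped and the same modeling code runs unsharded.

```python
import torch
from torch import nn


class TensorParallelRowLinearSketch(nn.Module):
    """Hypothetical 'shell' TP layer: wraps a plain Linear, knows about sharding."""

    def __init__(self, linear: nn.Linear, process_group):
        super().__init__()
        self.linear = linear                # contained Linear (Fast/Bnb/GPTQ in practice)
        self.process_group = process_group  # real group when sharded, FakeGroup otherwise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        # No-op in the non-sharded case: FakeGroup.size() == 1.
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
```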
72 lines · 1.9 KiB · Python
import os
from datetime import timedelta

import torch


class FakeBarrier:
    def wait(self):
        pass


class FakeGroup:
    """Single-process stand-in for a torch.distributed process group.

    Collective calls are no-ops that return a FakeBarrier, so non-sharded code
    can go through the same code path as sharded code.
    """

    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        # "Gather" on a single rank: copy the local tensor into every output slot.
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_torch_distributed():
    """Read RANK/WORLD_SIZE from the environment and return (process_group, rank, world_size).

    Falls back to a no-op FakeGroup when world_size == 1 or DEBUG=1.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    if world_size == 1:
        return FakeGroup(rank, world_size), rank, world_size
    else:
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(rank, world_size), rank, world_size

        # Call the init process.
        torch.distributed.init_process_group(
            backend=backend,
            world_size=world_size,
            rank=rank,
            timeout=timedelta(seconds=60),
            pg_options=options,
        )

        return torch.distributed.group.WORLD, rank, world_size
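A minimal usage sketch of the helper above, assuming a single process with `RANK`/`WORLD_SIZE` unset and the definitions above in scope; in that case the function returns a `FakeGroup`, so collective calls are no-ops and no `torch.distributed` initialization happens.

```python
# Hypothetical caller, not part of the file above.
process_group, rank, world_size = initialize_torch_distributed()

print(process_group.rank(), process_group.size())  # 0 1 on a single process
process_group.barrier().wait()                      # no-op via FakeBarrier
```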