text-generation-inference/server/text_generation_server/models/flash_rw.py
Nicolas Patry abd58ff82c
feat(server): Rework model loading (#344)
# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
2023-06-08 14:51:52 +02:00

67 lines
1.9 KiB
Python

import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
config.quantize = quantize
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)