diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
new file mode 100644
index 00000000..5b89d275
--- /dev/null
+++ b/server/text_generation_server/models/__init__.py
@@ -0,0 +1,87 @@
+import torch
+
+from transformers import AutoConfig
+from typing import Optional
+
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.opt import OPT, OPTSharded
+from text_generation_server.models.t5 import T5Sharded
+
+__all__ = [
+    "Model",
+    "BLOOM",
+    "BLOOMSharded",
+    "CausalLM",
+    "Galactica",
+    "GalacticaSharded",
+    "GPTNeox",
+    "GPTNeoxSharded",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "OPT",
+    "OPTSharded",
+    "T5Sharded",
+    "get_model",
+]
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+
+def get_model(
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+) -> Model:
+    if model_id.startswith("facebook/galactica"):
+        if sharded:
+            return GalacticaSharded(model_id, revision, quantize=quantize)
+        else:
+            return Galactica(model_id, revision, quantize=quantize)
+
+    if "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
+
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
+
+    if config.model_type == "bloom":
+        if sharded:
+            return BLOOMSharded(model_id, revision, quantize=quantize)
+        else:
+            return BLOOM(model_id, revision, quantize=quantize)
+
+    if config.model_type == "gpt_neox":
+        if sharded:
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
+        else:
+            return GPTNeox(model_id, revision, quantize=quantize)
+
+    if config.model_type == "t5":
+        if sharded:
+            return T5Sharded(model_id, revision, quantize=quantize)
+        else:
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
+
+    if config.model_type == "opt":
+        if sharded:
+            return OPTSharded(model_id, revision, quantize=quantize)
+        else:
+            return OPT(model_id, revision, quantize=quantize)
+
+    if sharded:
+        raise ValueError("sharded is not supported for AutoModel")
+    try:
+        return CausalLM(model_id, revision, quantize=quantize)
+    except Exception:
+        return Seq2SeqLM(model_id, revision, quantize=quantize)
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index ce3895ca..5a6a9c0d 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -122,18 +122,11 @@
                     slice_ = f.get_slice(name)
 
                     if isinstance(module, TensorParallelColumnLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[0]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop]
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
                     elif isinstance(module, TensorParallelRowLinear):
                         if param_name == "weight":
                             size = slice_.get_shape()[1]
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index f7fbb2ad..0022a50d 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation_server.models import CausalLM
-from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation_server.utils import (
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.opt import OPT, OPTSharded
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,
@@ -158,7 +158,7 @@
         )
 
 
-class Galactica(CausalLM):
+class Galactica(OPT):
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return GalacticaCausalLMBatch
@@ -184,7 +184,7 @@ class Galactica(CausalLM):
         return outputs.logits, outputs.past_key_values
 
 
-class GalacticaSharded(Galactica):
+class GalacticaSharded(OPTSharded):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
@@ -253,18 +253,11 @@ class GalacticaSharded(Galactica):
                     slice_ = f.get_slice(name)
 
                     if isinstance(module, TensorParallelColumnLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[0]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop]
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
                     elif isinstance(module, TensorParallelRowLinear):
                         if param_name == "weight":
                             size = slice_.get_shape()[1]
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
new file mode 100644
index 00000000..6f437957
--- /dev/null
+++ b/server/text_generation_server/models/opt.py
@@ -0,0 +1,233 @@
+import torch
+import torch.distributed
+
+from typing import List, Optional, Tuple
+
+from accelerate import init_empty_weights
+from safetensors import safe_open
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+)
+from transformers.models.opt.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
+
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    download_weights,
+)
+
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except Exception as e:
+    HAS_BITS_AND_BYTES = False
+
+
+class OPT(CausalLM):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
+        # Model Forward
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        return outputs.logits, outputs.past_key_values
+
+
+class OPTSharded(OPT):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16
+        else:
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left"
+        )
+
+        config = AutoConfig.from_pretrained(
+            model_id, revision=revision, tp_parallel=True
+        )
+        tokenizer.pad_token_id = config.pad_token_id
+
+        # Only the master process downloads the weights
+        if self.master:
+            download_weights(model_id, revision=revision, extension=".safetensors")
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        if not filenames:
+            raise ValueError("No safetensors weights found")
+
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            quantize=quantize,
+            device=device,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        self.model = model.eval().to(dtype)
+        torch.distributed.barrier(group=self.process_group)
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
+    ):
+        parameters = dict(model.named_parameters())
+        for file in filenames:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
+                for name in f.keys():
+                    if name == "lm_head.weight":
+                        continue
+
+                    full_name = f"model.{name}"
+
+                    module_name, param_name = full_name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+                    current_tensor = parameters[full_name]
+
+                    slice_ = f.get_slice(name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        tensor = slice_[:]
+
+                    if current_tensor.shape != tensor.shape:
+                        raise ValueError(
+                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
+                        )
+
+                    tensor = tensor.contiguous()
+
+                    if quantize:
+                        if not HAS_BITS_AND_BYTES:
+                            raise ImportError(
+                                "bitsandbytes is not available on your machine either because it is not installed "
+                                "or you don't have a GPU.\n"
+                                "You can install it with `pip install bitsandbytes`."
+                            )
+
+                        if (
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
+                        ):
+                            tensor = Int8Params(
+                                tensor,
+                                has_fp16_weights=False,
+                                requires_grad=False,
+                            ).to(device)
+                            state = bnb.MatmulLtState()
+                            state.threshold = 6.0
+                            state.has_fp16_weights = False
+                            state.memory_efficient_backward = False
+                            state.use_pool = True
+                            state.CB = tensor.CB
+                            state.SCB = tensor.SCB
+                            tensor.CB = None
+                            tensor.SCB = None
+
+                            def replace_linear(state):
+                                def linear(input, weight, bias):
+                                    out = bnb.matmul(
+                                        input,
+                                        weight,
+                                        state=state,
+                                        threshold=state.threshold,
+                                        bias=bias,
+                                    )
+
+                                    if state.CB is not None:
+                                        # we converted 8-bit row major to turing/ampere format
+                                        # in the first inference pass
+                                        # we no longer need the row-major weight
+                                        del state.CB
+                                        weight.data = state.CxB
+
+                                    return out
+
+                                return linear
+
+                            module.linear = replace_linear(state)
+
+                        else:
+                            tensor = tensor.to(device)
+
+                    module._parameters[param_name] = tensor
+                    if full_name == "model.decoder.embed_tokens.weight":
+                        model.lm_head._parameters["weight"] = tensor
+
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+        # Logits are sharded, so we need to gather them
+        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+        logits = torch.cat(logits, dim=2)
+
+        return logits, outputs.past_key_values
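
Note (illustration, not part of the diff): the load_weights methods in bloom.py, galactica.py, and opt.py above all shard checkpoint tensors with the same block arithmetic, which the OPT hunks simplify for column-parallel layers. The sketch below is a minimal, standalone restatement of that arithmetic with plain torch tensors standing in for safetensors slices; the helper names shard_column_weight and shard_row_weight are hypothetical and exist only for this example, and, like the diff's integer division, it assumes the sharded dimension divides evenly by world_size.

```python
import torch


def shard_column_weight(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Column-parallel: split dim 0 (output features), mirroring the
    # TensorParallelColumnLinear branch in load_weights.
    block_size = weight.shape[0] // world_size
    return weight[rank * block_size : (rank + 1) * block_size]


def shard_row_weight(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Row-parallel: split dim 1 (input features), mirroring the
    # TensorParallelRowLinear "weight" branch; the diff zeroes the bias on
    # every rank except 0 so it is added exactly once.
    block_size = weight.shape[1] // world_size
    return weight[:, rank * block_size : (rank + 1) * block_size]


if __name__ == "__main__":
    w = torch.arange(24.0).reshape(4, 6)
    print(shard_column_weight(w, rank=0, world_size=2).shape)  # torch.Size([2, 6])
    print(shard_row_weight(w, rank=1, world_size=2).shape)  # torch.Size([4, 3])
```

The same split explains the gather in OPTSharded.forward: the vocabulary dimension of the logits is sharded across ranks, so each rank's logits are all-gathered and concatenated along dim=2 before being returned.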