From bcf145733c1ac8944a14feb23cd3a7ce8ffe5df7 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 23 Jan 2024 01:37:09 +0000 Subject: [PATCH] feat: initial weight load --- .../text_generation_server/models/__init__.py | 21 +- .../models/custom_modeling/mamba_modeling.py | 190 ++++++++++++++++++ server/text_generation_server/models/mamba.py | 62 ++++++ 3 files changed, 272 insertions(+), 1 deletion(-) create mode 100644 server/text_generation_server/models/custom_modeling/mamba_modeling.py create mode 100644 server/text_generation_server/models/mamba.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 39d1d58ec..f22456df7 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.mamba import Mamba # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -161,7 +162,25 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict["model_type"] + model_type = config_dict.get("model_type", None) + if model_type is None: + # TODO: fix how we determine model type for Mamba + if "ssm_cfg" in config_dict: + # *only happens in Mamba case + model_type = "ssm" + else: + raise RuntimeError( + f"Could not determine model type for {model_id} revision {revision}" + ) + + if model_type == "ssm": + return Mamba( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "gpt_bigcode": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py new file mode 100644 index 000000000..d98352deb --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -0,0 +1,190 @@ +import torch +import torch.distributed + +import math +from torch import nn +from typing import Optional, List, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelHead, + FastLinear, + FastRMSNorm, +) + +class MambaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=51200, + n_positions=2048, + n_embd=2560, + n_layer=32, + n_inner=None, + n_head=32, + rotary_dim=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_vocab_size_multiple=64, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + no_bias=False, + rms_norm_eps=1e-8, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_inner = n_inner + self.n_head = n_head + self.rotary_dim = rotary_dim + + self.layer_norm_epsilon = layer_norm_epsilon + self.tie_word_embeddings = tie_word_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.no_bias = no_bias + self.rms_norm_eps = rms_norm_eps + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +class MambaBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + # TODO: adjust how weights are loaded + + # conv1d 768*2, 768*2, 4 + self.conv1 = nn.Conv1d(768, 768, 4) + # add weight and bias to conv1 + self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1)) + self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias")) + + # TODO: load weights in correctly for other operations + self.dt_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.dt_proj", + weights=weights, + bias=True, + ) + self.in_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.in_proj", + weights=weights, + bias=False, + ) + self.x_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.x_proj", + weights=weights, + bias=False, + ) + self.A_log = nn.Parameter(torch.randn(config.n_head, config.n_head, config.rotary_dim)) + self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim)) + + def forward( + self, + hidden_states, + past_kv_cache, + attention_mask=None, + ): + hidden_states_in_proj = self.in_proj(hidden_states) + hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1) + + hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1] + + import ipdb; ipdb.set_trace() + +class ResidualBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.layer_id = layer_id + self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.layer_norm = FastLinear.load( + config=config, + prefix=f"{layer_id}.norm", + weights=weights, + bias=False, + ) + + def forward( + self, + hidden_states, + kv_cache, + attention_mask, + ): + residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask) + hidden_states = residual + attn_outputs + return hidden_states, residual + +class MambaModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.tp_rank = weights.process_group.rank() + self.tp_world_size = weights.process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="backbone.embedding", weights=weights + ) + self.blocks = nn.ModuleList( + [ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)] + ) + self.norm_f = FastRMSNorm.load( + prefix="backbone.norm_f", + weights=weights, + eps=config.rms_norm_eps + ) + print("🌈 model init done") + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + hidden_states = self.embed_tokens(input_ids) + seq_len = hidden_states.shape[1] + mask = None if seq_len <= 1 else attention_mask + + past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values + + for index, block in enumerate(self.blocks): + hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask) + past_key_values[index] = new_key_values + + hidden_states = self.norm_f(hidden_states) + return hidden_states, past_key_values + +class MambaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.model = MambaModel(config, weights) + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + model_output = self.model( + input_ids, past_key_values, attention_mask, return_dict, use_cache + ) + print("🌈 model output done") \ No newline at end of file diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py new file mode 100644 index 000000000..13b2a7d68 --- /dev/null +++ b/server/text_generation_server/models/mamba.py @@ -0,0 +1,62 @@ +import torch +import torch.distributed + +from transformers import AutoConfig, AutoTokenizer +from typing import Optional, List, Tuple + +from text_generation_server.models import CausalLM +from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +class Mamba(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, _rank, _world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/gpt-neox-20b", + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + config = MambaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + tokenizer.bos_token_id = config.bos_token_id + tokenizer.eos_token_id = config.eos_token_id + tokenizer.pad_token = tokenizer.eos_token + + config.quantize = quantize + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + model = MambaForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) \ No newline at end of file