feat: initial weight load

drbh 2024-01-23 01:37:09 +00:00
parent 98e5faff9d
commit bcf145733c
3 changed files with 272 additions and 1 deletion

View File

@@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.mamba import Mamba

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -161,7 +162,25 @@ def get_model(
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
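
Since Mamba checkpoints at this point do not declare a model_type in their config.json, the hunk above falls back to probing for the ssm_cfg key. A minimal sketch of that detection path; the values in the example dict are assumed for illustration and are not taken from this commit:

# Illustrative sketch only: config values are assumed, not from the commit.
# The point is the "ssm_cfg" fallback used in get_model above.
config_dict = {"d_model": 768, "n_layer": 24, "ssm_cfg": {}, "vocab_size": 50277}

model_type = config_dict.get("model_type", None)
if model_type is None and "ssm_cfg" in config_dict:
    # only Mamba-style checkpoints carry an ssm_cfg entry
    model_type = "ssm"

assert model_type == "ssm"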

View File

@@ -0,0 +1,190 @@
import torch
import torch.distributed

import math

from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelHead,
    FastLinear,
    FastRMSNorm,
)

class MambaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=51200,
        n_positions=2048,
        n_embd=2560,
        n_layer=32,
        n_inner=None,
        n_head=32,
        rotary_dim=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_vocab_size_multiple=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        no_bias=False,
        rms_norm_eps=1e-8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.rotary_dim = rotary_dim
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_word_embeddings = tie_word_embeddings
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.no_bias = no_bias
        self.rms_norm_eps = rms_norm_eps

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

class MambaBlock(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        # TODO: adjust how weights are loaded
        # conv1d 768*2, 768*2, 4
        self.conv1 = nn.Conv1d(768, 768, 4)
        # add weight and bias to conv1
        self.conv1.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1)
        )
        self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))

        # TODO: load weights in correctly for other operations
        self.dt_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.dt_proj",
            weights=weights,
            bias=True,
        )
        self.in_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.in_proj",
            weights=weights,
            bias=False,
        )
        self.x_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.x_proj",
            weights=weights,
            bias=False,
        )
        self.A_log = nn.Parameter(
            torch.randn(config.n_head, config.n_head, config.rotary_dim)
        )
        self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim))

    def forward(
        self,
        hidden_states,
        past_kv_cache,
        attention_mask=None,
    ):
        hidden_states_in_proj = self.in_proj(hidden_states)
        hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1)
        hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1]
        # WIP: the rest of the forward pass is not implemented yet; this
        # breakpoint is left in to inspect the freshly loaded weights.
        import ipdb; ipdb.set_trace()

class ResidualBlock(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.layer_id = layer_id
        self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
        self.layer_norm = FastLinear.load(
            config=config,
            prefix=f"{layer_id}.norm",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_states,
        kv_cache,
        attention_mask,
    ):
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
        hidden_states = residual + attn_outputs
        return hidden_states, residual

class MambaModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="backbone.embedding", weights=weights
        )
        self.blocks = nn.ModuleList(
            [ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
        )
        self.norm_f = FastRMSNorm.load(
            prefix="backbone.norm_f",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        print("🌈 model init done")

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        hidden_states = self.embed_tokens(input_ids)
        seq_len = hidden_states.shape[1]
        mask = None if seq_len <= 1 else attention_mask

        past_key_values = (
            [None] * len(self.blocks) if past_key_values is None else past_key_values
        )
        for index, block in enumerate(self.blocks):
            hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
            past_key_values[index] = new_key_values

        hidden_states = self.norm_f(hidden_states)
        return hidden_states, past_key_values

class MambaForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.model = MambaModel(config, weights)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        model_output = self.model(
            input_ids, past_key_values, attention_mask, return_dict, use_cache
        )
        print("🌈 model output done")
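
The prefixes passed to weights.get_tensor and the .load helpers above imply a specific safetensors key layout for the checkpoint. A small sketch of the tensor names this initial loader would look up per layer; the layer count is an assumed example value, not something fixed by this commit:

# Sketch of the checkpoint keys implied by the prefixes used above.
# n_layer is an assumed example value for illustration only.
n_layer = 2
expected_keys = ["backbone.embedding.weight", "backbone.norm_f.weight"]
for i in range(n_layer):
    mixer = f"backbone.layers.{i}.mixer"
    expected_keys += [
        f"backbone.layers.{i}.norm.weight",
        f"{mixer}.conv1d.weight",
        f"{mixer}.conv1d.bias",
        f"{mixer}.in_proj.weight",
        f"{mixer}.x_proj.weight",
        f"{mixer}.dt_proj.weight",
        f"{mixer}.dt_proj.bias",
    ]
print("\n".join(expected_keys))

Note that A_log and D are initialized with torch.randn rather than loaded, so no keys are expected for them yet; the TODOs above flag this.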

View File

@@ -0,0 +1,62 @@
import torch
import torch.distributed

from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple

from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

class Mamba(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/gpt-neox-20b",
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = MambaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        tokenizer.bos_token_id = config.bos_token_id
        tokenizer.eos_token_id = config.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

        config.quantize = quantize
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)

        model = MambaForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
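
For reference, a hypothetical usage sketch that mirrors the constructor signature above; the checkpoint id is an assumed example and is not referenced anywhere in this commit:

# Assumed example: constructing the new wrapper directly. In practice the
# server reaches this class via the ssm dispatch added to get_model above.
import torch
from text_generation_server.models.mamba import Mamba

model = Mamba(
    "state-spaces/mamba-130m",  # assumed example checkpoint id
    revision=None,
    quantize=None,
    dtype=torch.float16,
    trust_remote_code=False,
)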