mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
feat: initial weight load
This commit is contained in:
parent
98e5faff9d
commit
bcf145733c
@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded
|
|||||||
from text_generation_server.models.santacoder import SantaCoder
|
from text_generation_server.models.santacoder import SantaCoder
|
||||||
from text_generation_server.models.t5 import T5Sharded
|
from text_generation_server.models.t5 import T5Sharded
|
||||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||||
|
from text_generation_server.models.mamba import Mamba
|
||||||
|
|
||||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||||
# in PyTorch 1.12 and later.
|
# in PyTorch 1.12 and later.
|
||||||
@ -161,7 +162,25 @@ def get_model(
|
|||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||||
|
|
||||||
model_type = config_dict["model_type"]
|
model_type = config_dict.get("model_type", None)
|
||||||
|
if model_type is None:
|
||||||
|
# TODO: fix how we determine model type for Mamba
|
||||||
|
if "ssm_cfg" in config_dict:
|
||||||
|
# *only happens in Mamba case
|
||||||
|
model_type = "ssm"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not determine model type for {model_id} revision {revision}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type == "ssm":
|
||||||
|
return Mamba(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == "gpt_bigcode":
|
if model_type == "gpt_bigcode":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
@ -0,0 +1,190 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, List, Tuple, Any
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelHead,
|
||||||
|
FastLinear,
|
||||||
|
FastRMSNorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MambaConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=51200,
|
||||||
|
n_positions=2048,
|
||||||
|
n_embd=2560,
|
||||||
|
n_layer=32,
|
||||||
|
n_inner=None,
|
||||||
|
n_head=32,
|
||||||
|
rotary_dim=32,
|
||||||
|
layer_norm_epsilon=1e-5,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
pad_vocab_size_multiple=64,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
no_bias=False,
|
||||||
|
rms_norm_eps=1e-8,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.n_positions = n_positions
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_inner = n_inner
|
||||||
|
self.n_head = n_head
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
|
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.no_bias = no_bias
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MambaBlock(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
# TODO: adjust how weights are loaded
|
||||||
|
|
||||||
|
# conv1d 768*2, 768*2, 4
|
||||||
|
self.conv1 = nn.Conv1d(768, 768, 4)
|
||||||
|
# add weight and bias to conv1
|
||||||
|
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1))
|
||||||
|
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
|
||||||
|
|
||||||
|
# TODO: load weights in correctly for other operations
|
||||||
|
self.dt_proj = TensorParallelColumnLinear.load(
|
||||||
|
config=config,
|
||||||
|
prefix=f"{prefix}.dt_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.in_proj = TensorParallelColumnLinear.load(
|
||||||
|
config=config,
|
||||||
|
prefix=f"{prefix}.in_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.x_proj = TensorParallelColumnLinear.load(
|
||||||
|
config=config,
|
||||||
|
prefix=f"{prefix}.x_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.A_log = nn.Parameter(torch.randn(config.n_head, config.n_head, config.rotary_dim))
|
||||||
|
self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
past_kv_cache,
|
||||||
|
attention_mask=None,
|
||||||
|
):
|
||||||
|
hidden_states_in_proj = self.in_proj(hidden_states)
|
||||||
|
hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1)
|
||||||
|
|
||||||
|
hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1]
|
||||||
|
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||||
|
self.layer_norm = FastLinear.load(
|
||||||
|
config=config,
|
||||||
|
prefix=f"{layer_id}.norm",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
kv_cache,
|
||||||
|
attention_mask,
|
||||||
|
):
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
|
||||||
|
hidden_states = residual + attn_outputs
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
class MambaModel(nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.tp_rank = weights.process_group.rank()
|
||||||
|
self.tp_world_size = weights.process_group.size()
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix="backbone.embedding", weights=weights
|
||||||
|
)
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
|
||||||
|
)
|
||||||
|
self.norm_f = FastRMSNorm.load(
|
||||||
|
prefix="backbone.norm_f",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
print("🌈 model init done")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.ByteTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
seq_len = hidden_states.shape[1]
|
||||||
|
mask = None if seq_len <= 1 else attention_mask
|
||||||
|
|
||||||
|
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||||
|
|
||||||
|
for index, block in enumerate(self.blocks):
|
||||||
|
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
|
||||||
|
past_key_values[index] = new_key_values
|
||||||
|
|
||||||
|
hidden_states = self.norm_f(hidden_states)
|
||||||
|
return hidden_states, past_key_values
|
||||||
|
|
||||||
|
class MambaForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.model = MambaModel(config, weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.ByteTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
model_output = self.model(
|
||||||
|
input_ids, past_key_values, attention_mask, return_dict, use_cache
|
||||||
|
)
|
||||||
|
print("🌈 model output done")
|
62
server/text_generation_server/models/mamba.py
Normal file
62
server/text_generation_server/models/mamba.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.models import CausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Mamba(CausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
self.process_group, _rank, _world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"EleutherAI/gpt-neox-20b",
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
config = MambaConfig.from_pretrained(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.bos_token_id = config.bos_token_id
|
||||||
|
tokenizer.eos_token_id = config.eos_token_id
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
config.quantize = quantize
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
model = MambaForCausalLM(config, weights)
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user