Non flash MPT.

This commit is contained in:
Nicolas Patry 2023-06-30 09:52:49 +00:00
parent 2b53d71991
commit f33ad7ed98
4 changed files with 1232 additions and 0 deletions

View File

@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
@ -178,6 +179,10 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "mpt":
return MPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,74 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig
from typing import Optional
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("MPTSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
# config = AutoConfig.from_pretrained(
# # model_id, revision=revision, trust_remote_code=trust_remote_code
# model_id, revision=revision, trust_remote_code=False
# )
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -31,7 +31,19 @@ def load_layer_norm(cls, prefix, weights, eps):
return ln
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = None
return ln
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module):