mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Non flash MPT.
This commit is contained in:
parent
2b53d71991
commit
f33ad7ed98
@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
|
|||||||
from text_generation_server.models.causal_lm import CausalLM
|
from text_generation_server.models.causal_lm import CausalLM
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.bloom import BLOOMSharded
|
from text_generation_server.models.bloom import BLOOMSharded
|
||||||
|
from text_generation_server.models.mpt import MPTSharded
|
||||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||||
from text_generation_server.models.rw import RW
|
from text_generation_server.models.rw import RW
|
||||||
from text_generation_server.models.opt import OPTSharded
|
from text_generation_server.models.opt import OPTSharded
|
||||||
@ -178,6 +179,10 @@ def get_model(
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
elif model_type == "mpt":
|
||||||
|
return MPTSharded(
|
||||||
|
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
elif model_type == "gpt_neox":
|
elif model_type == "gpt_neox":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
1141
server/text_generation_server/models/custom_modeling/mpt_modeling.py
Normal file
1141
server/text_generation_server/models/custom_modeling/mpt_modeling.py
Normal file
File diff suppressed because it is too large
Load Diff
74
server/text_generation_server/models/mpt.py
Normal file
74
server/text_generation_server/models/mpt.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from opentelemetry import trace
|
||||||
|
from transformers import AutoTokenizer, PretrainedConfig
|
||||||
|
from typing import Optional
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import json
|
||||||
|
|
||||||
|
from text_generation_server.models import CausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||||
|
MPTForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MPTSharded(CausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("MPTSharded is only available on GPU")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
config = PretrainedConfig(**config)
|
||||||
|
config.quantize = quantize
|
||||||
|
# config = AutoConfig.from_pretrained(
|
||||||
|
# # model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
# model_id, revision=revision, trust_remote_code=False
|
||||||
|
# )
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
|
||||||
|
config.quantize = quantize
|
||||||
|
model = MPTForCausalLM(config, weights)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
@ -31,7 +31,19 @@ def load_layer_norm(cls, prefix, weights, eps):
|
|||||||
return ln
|
return ln
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_layer_norm_no_bias(cls, prefix, weights, eps):
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
with init_empty_weights():
|
||||||
|
ln = cls(weight.shape, eps=eps)
|
||||||
|
|
||||||
|
ln.weight = nn.Parameter(weight)
|
||||||
|
ln.bias = None
|
||||||
|
return ln
|
||||||
|
|
||||||
|
|
||||||
torch.nn.LayerNorm.load = load_layer_norm
|
torch.nn.LayerNorm.load = load_layer_norm
|
||||||
|
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||||
|
|
||||||
|
|
||||||
class FastLinear(nn.Module):
|
class FastLinear(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user