mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
# What does this PR do? Hotfixes: - Uses `model_type`=`gpt_bigcode` for more general usage. - Hotfixes the linked lm_head vs. wte_embedding (safetensors files correctly do not contain the key when the file is sharded, whereas PyTorch copies the tensor) <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? 
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal> Co-authored-by: OlivierDehaene <olivier@huggingface.co>
171 lines
6.0 KiB
Python
171 lines
6.0 KiB
Python
import torch
|
|
|
|
from loguru import logger
|
|
from transformers import AutoConfig
|
|
from transformers.models.auto import modeling_auto
|
|
from typing import Optional
|
|
|
|
from text_generation_server.models.model import Model
|
|
from text_generation_server.models.causal_lm import CausalLM
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
|
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
|
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
|
from text_generation_server.models.opt import OPT, OPTSharded
|
|
from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
|
from text_generation_server.models.santacoder import SantaCoder
|
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
|
from text_generation_server.models.t5 import T5Sharded
|
|
|
|
# Flash Attention models are only enabled on a CUDA device whose compute
# capability is 7.5, 8.x or 9.0. An unsupported GPU is reported by raising
# ImportError here, so it follows the same fallback path as missing kernels.
try:
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = (major, minor) == (7, 5)
        is_sm8x = major == 8  # every 8.x minor version is accepted
        is_sm90 = (major, minor) == (9, 0)
        supported = any((is_sm75, is_sm8x, is_sm90))

        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_neox import (
            FlashNeoX,
            FlashNeoXSharded,
        )
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        # No CUDA device: the flash model modules are never imported.
        FLASH_ATTENTION = False
except ImportError:
    # Raised either by the capability check above or by a failing flash
    # import; log it (with the exception attached) and fall back to the
    # non-flash implementations.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
|
|
|
|
# Public API of this package. Entries must be attribute *names* (strings):
# `from text_generation_server.models import *` raises TypeError on any
# non-str item in __all__.
__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

# The Flash Attention model classes are only bound in this module when their
# imports succeeded (FLASH_ATTENTION is True), so export them conditionally.
if FLASH_ATTENTION:
    __all__.extend(
        [
            "FlashNeoX",
            "FlashNeoXSharded",
            "FlashSantacoder",
            "FlashSantacoderSharded",
            "FlashLlama",
            "FlashLlamaSharded",
        ]
    )
|
|
|
|
# Error template used when a model variant needs the Flash Attention kernels
# but they are unavailable; the `{}` placeholder receives the variant name
# (e.g. "Sharded Llama").
FLASH_ATT_ERROR_MESSAGE = "\n".join(
    (
        "{} requires Flash Attention CUDA kernels to be installed.",
        "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
        "or install flash attention with `cd server && make install install-flash-attention`",
    )
)
|
|
|
|
# Allow TF32 for both CUDA matmuls and cuDNN. The matmul flag defaults to
# False in PyTorch 1.12 and later, while the cuDNN flag defaults to True.
# Both are set explicitly so the setting does not depend on the torch version.
# The matmul flag is set first, then cuDNN.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend

# Turn off autograd tracking.
torch.set_grad_enabled(False)
|
|
|
|
|
|
def _get_santacoder(
    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
) -> Model:
    """Instantiate a SantaCoder-family model, preferring Flash Attention.

    Used by both the ``bigcode/`` model-id route and the ``gpt_bigcode``
    model-type route of ``get_model`` so the two cannot drift apart.

    Raises:
        NotImplementedError: ``sharded`` is requested but the Flash Attention
            models are unavailable (only the flash variant can be sharded).
    """
    if sharded:
        if not FLASH_ATTENTION:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return FlashSantacoderSharded(model_id, revision, quantize=quantize)

    santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
    return santacoder_cls(model_id, revision, quantize=quantize)


def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
) -> Model:
    """Select and instantiate the serving model class for ``model_id``.

    Dispatch order:
      1. Galactica and ``bigcode/`` checkpoints, matched on the model id alone.
      2. Known architectures, matched on ``config.model_type`` from
         ``AutoConfig``. A Flash Attention implementation is preferred where
         one exists and ``FLASH_ATTENTION`` is true.
      3. Generic ``CausalLM`` / ``Seq2SeqLM`` wrappers for any other type that
         transformers maps to a causal-LM or seq2seq model.

    Args:
        model_id: Model identifier, passed to ``AutoConfig.from_pretrained``
            and to the selected model class.
        revision: Revision forwarded to the config loader and model class.
        sharded: Whether to load the sharded variant of the model.
        quantize: Quantization mode, forwarded as ``quantize=``.

    Returns:
        An instance of a ``Model`` subclass.

    Raises:
        NotImplementedError: A sharded SantaCoder or Llama model is requested
            but Flash Attention is unavailable.
        ValueError: ``sharded`` is requested for an architecture that only has
            a generic implementation, or the model type is unsupported.
    """
    if "facebook/galactica" in model_id:
        galactica_cls = GalacticaSharded if sharded else Galactica
        return galactica_cls(model_id, revision, quantize=quantize)

    # Matched on the model id before any config is fetched.
    if model_id.startswith("bigcode/"):
        return _get_santacoder(model_id, revision, sharded, quantize)

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "gpt_bigcode":
        return _get_santacoder(model_id, revision, sharded, quantize)

    if model_type == "bloom":
        bloom_cls = BLOOMSharded if sharded else BLOOM
        return bloom_cls(model_id, revision, quantize=quantize)

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
        return neox_cls(model_id, revision, quantize=quantize)

    if model_type == "llama":
        if sharded:
            # Sharded Llama has no non-flash fallback.
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
                )
            return FlashLlamaSharded(model_id, revision, quantize=quantize)
        llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
        return llama_cls(model_id, revision, quantize=quantize)

    if model_type == "opt":
        opt_cls = OPTSharded if sharded else OPT
        return opt_cls(model_id, revision, quantize=quantize)

    if model_type == "t5":
        t5_cls = T5Sharded if sharded else Seq2SeqLM
        return t5_cls(model_id, revision, quantize=quantize)

    # Everything below is served by the generic transformers-based wrappers,
    # which have no sharded implementation.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, quantize=quantize)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, quantize=quantize)

    raise ValueError(f"Unsupported model type {model_type}")
|