mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Creating doc automatically for supported models.
This commit is contained in:
parent fc0eaffc81
commit 1373c185c3
server/text_generation_server/models/__init__.py
@@ -1,4 +1,5 @@
 import torch
+import enum
 import os
 
 from loguru import logger
@@ -116,6 +117,138 @@ if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
 
+class ModelType(enum.Enum):
+    MAMBA = {
+        "type": "ssm",
+        "name": "Mamba",
+        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
+    }
+    GALACTICA = {
+        "type": "galactica",
+        "name": "Galactica",
+        "url": "https://huggingface.co/facebook/galactica-120b",
+    }
+    SANTACODER = {
+        "type": "santacoder",
+        "name": "SantaCoder",
+        "url": "https://huggingface.co/bigcode/santacoder",
+    }
+    BLOOM = {
+        "type": "bloom",
+        "name": "Bloom",
+        "url": "https://huggingface.co/bigscience/bloom-560m",
+    }
+    MPT = {
+        "type": "mpt",
+        "name": "Mpt",
+        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
+    }
+    GPT2 = {
+        "type": "gpt2",
+        "name": "Gpt2",
+        "url": "https://huggingface.co/openai-community/gpt2",
+    }
+    GPT_NEOX = {
+        "type": "gpt_neox",
+        "name": "Gpt Neox",
+        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
+    }
+    GPT_BIGCODE = {
+        "type": "gpt_bigcode",
+        "name": "Gpt Bigcode",
+        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
+    }
+    PHI = {
+        "type": "phi",
+        "name": "Phi",
+        "url": "https://huggingface.co/microsoft/phi-1_5",
+    }
+    PHI3 = {
+        "type": "phi3",
+        "name": "Phi 3",
+        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+    }
+    LLAMA = {
+        "type": "llama",
+        "name": "Llama",
+        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
+    }
+    BAICHUAN = {
+        "type": "baichuan",
+        "name": "Baichuan",
+        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+    }
+    GEMMA = {
+        "type": "gemma",
+        "name": "Gemma",
+        "url": "https://huggingface.co/google/gemma-7b",
+    }
+    COHERE = {
+        "type": "cohere",
+        "name": "Cohere",
+        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
+    }
+    DRBX = {
+        "type": "dbrx",
+        "name": "Dbrx",
+        "url": "https://huggingface.co/databricks/dbrx-instruct",
+    }
+    FALCON = {
+        "type": "falcon",
+        "name": "Falcon",
+        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
+    }
+    MISTRAL = {
+        "type": "mistral",
+        "name": "Mistral",
+        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
+    }
+    MIXTRAL = {
+        "type": "mixtral",
+        "name": "Mixtral",
+        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
+    }
+    STARCODER2 = {
+        "type": "starcoder2",
+        "name": "StarCoder 2",
+        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+    }
+    QWEN2 = {
+        "type": "qwen2",
+        "name": "Qwen 2",
+        "url": "https://huggingface.co/Qwen/Qwen1.5-7B",
+    }
+    OPT = {
+        "type": "opt",
+        "name": "Opt",
+        "url": "https://huggingface.co/facebook/opt-6.7b",
+    }
+    T5 = {
+        "type": "t5",
+        "name": "T5",
+        "url": "https://huggingface.co/google/flan-t5-xxl",
+    }
+    IDEFICS = {
+        "type": "idefics",
+        "name": "Idefics",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
+    }
+    IDEFICS2 = {
+        "type": "idefics2",
+        "name": "Idefics 2",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
+    }
+    LLAVA_NEXT = {
+        "type": "llava_next",
+        "name": "Llava Next (1.6)",
+        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+    }
+
+
+for data in ModelType:
+    globals()[data.name] = data.value["type"]
+
+
 def get_model(
     model_id: str,
     revision: Optional[str],
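The `for data in ModelType:` loop is what keeps the rest of this diff mechanical: it re-exports each member's config-file `type` string as a module-level constant, so every dispatch branch below can swap a string literal for a named constant with no other change. A minimal, self-contained sketch of the pattern (standalone, not the actual TGI module):

import enum


class ModelType(enum.Enum):
    # Two representative entries; the real enum lists ~25 architectures.
    MAMBA = {"type": "ssm", "name": "Mamba"}
    BLOOM = {"type": "bloom", "name": "Bloom"}


# Re-export each member's config identifier as a module-level constant.
for data in ModelType:
    globals()[data.name] = data.value["type"]

# After the loop, MAMBA and BLOOM are plain strings at module scope:
assert MAMBA == "ssm"  # noqa: F821 -- name is created dynamically
assert BLOOM == "bloom"  # noqa: F821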
@@ -267,7 +400,7 @@ def get_model(
         else:
             logger.info(f"Unknown quantization method {method}")
 
-    if model_type == "ssm":
+    if model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
@@ -288,8 +421,8 @@ def get_model(
         )
 
     if (
-        model_type == "gpt_bigcode"
-        or model_type == "gpt2"
+        model_type == GPT_BIGCODE
+        or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
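One thing to note in the hunk above: Python's `and` binds tighter than `or`, so the condition parses as `GPT_BIGCODE or (GPT2 and model_id.startswith("bigcode/"))`, meaning the `bigcode/` prefix check only constrains the `gpt2` arm. A quick illustration of the precedence (hypothetical values, not TGI code):

# `A or B and C` parses as `A or (B and C)`.
A, B, C = True, False, False
assert (A or B and C) == (A or (B and C))   # both True
assert (A or B and C) != ((A or B) and C)   # True vs. False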
@@ -315,7 +448,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "bloom":
+    if model_type == BLOOM:
         return BLOOMSharded(
             model_id,
             revision,
@@ -324,7 +457,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "mpt":
+    elif model_type == MPT:
         return MPTSharded(
             model_id,
             revision,
@@ -333,7 +466,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt2":
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             return FlashGPT2(
                 model_id,
@@ -354,7 +487,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt_neox":
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
@@ -383,7 +516,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "phi":
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashPhi(
                 model_id,
@@ -418,7 +551,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
@@ -439,7 +572,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "gemma":
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashGemma(
                 model_id,
@@ -461,7 +594,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "cohere":
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCohere(
                 model_id,
@@ -483,7 +616,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "dbrx":
+    if model_type == DRBX:
         if FLASH_ATTENTION:
             return FlashDbrx(
                 model_id,
@@ -505,7 +638,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
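The Falcon branch is the one place where the new constants mix with legacy literals: older Falcon configs report `model_type` as "RefinedWeb" or "RefinedWebModel", so those strings stay, and `FALCON` (now just the string "falcon" after the globals() loop) slots in beside them. A sketch of why the mixed list still works:

FALCON = "falcon"  # what the globals() loop effectively binds

for model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
    assert model_type in ["RefinedWeb", "RefinedWebModel", FALCON]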
@@ -539,7 +672,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mistral":
+    if model_type == MISTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
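The Mistral branch, and the Mixtral/StarCoder2/Qwen2 branches that follow, all gate the flash path on the same shape of condition: take the fast kernel when the model declares no sliding window (absent, `None`, or `-1`), or when the installed kernels can handle windowing. A hedged sketch of that gate; `SUPPORTS_WINDOWING` is an assumed capability flag standing in for whatever the second `or` arm checks past the hunk boundary:

FLASH_ATTENTION = True      # assumed: flash kernels importable
SUPPORTS_WINDOWING = False  # assumed: kernels handle sliding windows


def can_use_flash(config_dict: dict) -> bool:
    # Mirrors the gate in the diff: no window declared -> flash is safe;
    # otherwise we need windowing-capable kernels.
    sliding_window = config_dict.get("sliding_window", -1)
    no_window = sliding_window is None or sliding_window == -1
    return (no_window and FLASH_ATTENTION) or SUPPORTS_WINDOWING


print(can_use_flash({}))                        # True: no window declared
print(can_use_flash({"sliding_window": 4096}))  # False without windowing support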
@@ -566,7 +699,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mixtral":
+    if model_type == MIXTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -593,7 +726,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "starcoder2":
+    if model_type == STARCODER2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -621,7 +754,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "qwen2":
+    if model_type == QWEN2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -647,7 +780,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "opt":
+    if model_type == OPT:
         return OPTSharded(
             model_id,
             revision,
@@ -657,7 +790,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "t5":
+    if model_type == T5:
         return T5Sharded(
             model_id,
             revision,
@@ -666,7 +799,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "idefics":
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -678,7 +811,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "idefics2":
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return Idefics2(
                 model_id,
@@ -703,7 +836,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
-    if model_type == "llava_next":
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return LlavaNext(
                 model_id,
@ -1,13 +1,9 @@
|
||||
import subprocess
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def check_cli(check: bool):
|
||||
output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
|
||||
"utf-8"
|
||||
)
|
||||
@@ -41,7 +37,7 @@ def main():
     block = []
 
     filename = "docs/source/basic_tutorials/launcher.md"
-    if args.check:
+    if check:
         with open(filename, "r") as f:
             doc = f.read()
             if doc != final_doc:
@@ -60,5 +56,34 @@ def main():
         f.write(final_doc)
 
 
+def check_supported_models(check: bool):
+    filename = "server/text_generation_server/models/__init__.py"
+    with open(filename, "r") as f:
+        tree = ast.parse(f.read())
+
+    enum_def = [
+        x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType"
+    ][0]
+    _locals = {}
+    _globals = {}
+    exec(f"import enum\n{ast.unparse(enum_def)}", _globals, _locals)
+    ModelType = _locals["ModelType"]
+    for data in ModelType:
+        print(data.name)
+        print(f" type: {data.value['type']}")
+        print(f" name: {data.value['name']}")
+        print(f" url: {data.value['url']}")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check", action="store_true")
+
+    args = parser.parse_args()
+
+    # check_cli(args.check)
+    check_supported_models(args.check)
+
+
 if __name__ == "__main__":
     main()
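`check_supported_models` deliberately avoids importing the server package, which would drag in torch and the CUDA stack at doc-build time: it parses the source with `ast`, plucks out the `ModelType` ClassDef, and `exec`s just that class with `enum` injected. A standalone sketch of the technique against a hypothetical source string (the real script reads the file named above):

import ast

# Hypothetical module source; stands in for models/__init__.py. Note the
# heavy import that we never want to execute during doc generation.
source = """
import torch

class ModelType(enum.Enum):
    BLOOM = {"type": "bloom", "name": "Bloom"}
"""

tree = ast.parse(source)
enum_def = [
    x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType"
][0]

_locals = {}
# Execute only the class definition (ast.unparse needs Python 3.9+); torch
# is never imported because every other top-level statement is skipped.
exec(f"import enum\n{ast.unparse(enum_def)}", {}, _locals)

ModelType = _locals["ModelType"]
for data in ModelType:
    print(data.name, "->", data.value["type"])

As extracted, the function prints the table and ignores its `check` argument; presumably the `--check` wiring is meant to mirror how `check_cli` diffs the committed launcher.md.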