Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 20:34:54 +00:00)

commit 1373c185c3
parent fc0eaffc81

    Creating doc automatically for supported models.
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,4 +1,5 @@
 import torch
+import enum
 import os
 
 from loguru import logger
@@ -116,6 +117,138 @@ if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
 
+class ModelType(enum.Enum):
+    MAMBA = {
+        "type": "ssm",
+        "name": "Mamba",
+        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
+    }
+    GALACTICA = {
+        "type": "galactica",
+        "name": "Galactica",
+        "url": "https://huggingface.co/facebook/galactica-120b",
+    }
+    SANTACODER = {
+        "type": "santacoder",
+        "name": "SantaCoder",
+        "url": "https://huggingface.co/bigcode/santacoder",
+    }
+    BLOOM = {
+        "type": "bloom",
+        "name": "Bloom",
+        "url": "https://huggingface.co/bigscience/bloom-560m",
+    }
+    MPT = {
+        "type": "mpt",
+        "name": "Mpt",
+        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
+    }
+    GPT2 = {
+        "type": "gpt2",
+        "name": "Gpt2",
+        "url": "https://huggingface.co/openai-community/gpt2",
+    }
+    GPT_NEOX = {
+        "type": "gpt_neox",
+        "name": "Gpt Neox",
+        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
+    }
+    GPT_BIGCODE = {
+        "type": "gpt_bigcode",
+        "name": "Gpt Bigcode",
+        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
+    }
+    PHI = {
+        "type": "phi",
+        "name": "Phi",
+        "url": "https://huggingface.co/microsoft/phi-1_5",
+    }
+    PHI3 = {
+        "type": "phi3",
+        "name": "Phi 3",
+        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+    }
+    LLAMA = {
+        "type": "llama",
+        "name": "Llama",
+        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
+    }
+    BAICHUAN = {
+        "type": "baichuan",
+        "name": "Baichuan",
+        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+    }
+    GEMMA = {
+        "type": "gemma",
+        "name": "Gemma",
+        "url": "https://huggingface.co/google/gemma-7b",
+    }
+    COHERE = {
+        "type": "cohere",
+        "name": "Cohere",
+        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
+    }
+    DRBX = {
+        "type": "drbx",
+        "name": "Drbx",
+        "url": "https://huggingface.co/databricks/dbrx-instruct",
+    }
+    FALCON = {
+        "type": "falcon",
+        "name": "Falcon",
+        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
+    }
+    MISTRAL = {
+        "type": "mistral",
+        "name": "Mistral",
+        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
+    }
+    MIXTRAL = {
+        "type": "mixtral",
+        "name": "Mixtral",
+        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
+    }
+    STARCODER2 = {
+        "type": "starcoder2",
+        "name": "StarCoder 2",
+        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+    }
+    QWEN2 = {
+        "type": "qwen2",
+        "name": "Qwen 2",
+        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+    }
+    OPT = {
+        "type": "opt",
+        "name": "Opt",
+        "url": "https://huggingface.co/facebook/opt-6.7b",
+    }
+    T5 = {
+        "type": "t5",
+        "name": "T5",
+        "url": "https://huggingface.co/google/flan-t5-xxl",
+    }
+    IDEFICS = {
+        "type": "idefics",
+        "name": "Idefics",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
+    }
+    IDEFICS2 = {
+        "type": "idefics2",
+        "name": "Idefics 2",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
+    }
+    LLAVA_NEXT = {
+        "type": "llava_next",
+        "name": "Llava Next (1.6)",
+        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+    }
+
+
+for data in ModelType:
+    globals()[data.name] = data.value["type"]
+
+
 def get_model(
     model_id: str,
     revision: Optional[str],
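Note: the `for data in ModelType` loop at the end of this hunk is the load-bearing trick. It rebinds each member name at module scope to that member's "type" string, so the dispatch branches rewritten below still perform plain string comparisons against the `model_type` field of the model's config, just spelled as bare constants. A minimal self-contained sketch of the same mechanism (single illustrative member, not the full enum):

    import enum


    class ModelType(enum.Enum):
        MAMBA = {"type": "ssm", "name": "Mamba"}


    # Rebind the member name at module scope to its "type" string,
    # exactly as the loop in the hunk above does.
    for data in ModelType:
        globals()[data.name] = data.value["type"]

    assert MAMBA == "ssm"  # MAMBA is now the plain string "ssm", not ModelType.MAMBA
    assert ModelType.MAMBA.value["name"] == "Mamba"  # metadata stays on the enum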
@@ -267,7 +400,7 @@ def get_model(
         else:
             logger.info(f"Unknown quantization method {method}")
 
-    if model_type == "ssm":
+    if model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
@@ -288,8 +421,8 @@ def get_model(
         )
 
     if (
-        model_type == "gpt_bigcode"
-        or model_type == "gpt2"
+        model_type == GPT_BIGCODE
+        or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
@@ -315,7 +448,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "bloom":
+    if model_type == BLOOM:
         return BLOOMSharded(
             model_id,
             revision,
@@ -324,7 +457,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "mpt":
+    elif model_type == MPT:
         return MPTSharded(
             model_id,
             revision,
@@ -333,7 +466,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt2":
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             return FlashGPT2(
                 model_id,
@@ -354,7 +487,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt_neox":
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
@@ -383,7 +516,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "phi":
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashPhi(
                 model_id,
@@ -418,7 +551,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
@@ -439,7 +572,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "gemma":
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashGemma(
                 model_id,
@@ -461,7 +594,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "cohere":
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCohere(
                 model_id,
@@ -483,7 +616,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "dbrx":
+    if model_type == DRBX:
         if FLASH_ATTENTION:
             return FlashDbrx(
                 model_id,
@@ -505,7 +638,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
@@ -539,7 +672,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mistral":
+    if model_type == MISTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -566,7 +699,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mixtral":
+    if model_type == MIXTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -593,7 +726,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "starcoder2":
+    if model_type == STARCODER2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -621,7 +754,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "qwen2":
+    if model_type == QWEN2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -647,7 +780,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "opt":
+    if model_type == OPT:
         return OPTSharded(
             model_id,
             revision,
@@ -657,7 +790,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "t5":
+    if model_type == T5:
         return T5Sharded(
             model_id,
             revision,
@@ -666,7 +799,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "idefics":
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -678,7 +811,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "idefics2":
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return Idefics2(
                 model_id,
@@ -703,7 +836,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
-    if model_type == "llava_next":
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return LlavaNext(
                 model_id,
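Note: the `name` and `url` fields on each member exist so that documentation for supported models can be generated from the enum, per the commit message. A hedged sketch of how such a list could be rendered from that metadata; the `render_supported_models` helper and the markdown layout are illustrative, not the commit's actual output:

    import enum


    class ModelType(enum.Enum):
        # Two members copied from the enum added above, for illustration.
        MAMBA = {
            "type": "ssm",
            "name": "Mamba",
            "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
        }
        BLOOM = {
            "type": "bloom",
            "name": "Bloom",
            "url": "https://huggingface.co/bigscience/bloom-560m",
        }


    def render_supported_models(models) -> str:
        # One markdown bullet per supported architecture, linking its example repo.
        lines = ["# Supported models", ""]
        for member in models:
            lines.append(f"- [{member.value['name']}]({member.value['url']})")
        return "\n".join(lines)


    print(render_supported_models(ModelType))

Feeding it the full `ModelType` from this diff would emit one bullet per supported architecture.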
--- a/update_doc.py
+++ b/update_doc.py
@@ -1,13 +1,9 @@
 import subprocess
 import argparse
+import ast
 
 
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--check", action="store_true")
-
-    args = parser.parse_args()
-
+def check_cli(check: bool):
     output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
         "utf-8"
     )
@@ -41,7 +37,7 @@ def main():
             block = []
 
     filename = "docs/source/basic_tutorials/launcher.md"
-    if args.check:
+    if check:
         with open(filename, "r") as f:
             doc = f.read()
             if doc != final_doc:
@@ -60,5 +56,34 @@ def main():
             f.write(final_doc)
 
 
+def check_supported_models(check: bool):
+    filename = "server/text_generation_server/models/__init__.py"
+    with open(filename, "r") as f:
+        tree = ast.parse(f.read())
+
+    enum_def = [
+        x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType"
+    ][0]
+    _locals = {}
+    _globals = {}
+    exec(f"import enum\n{ast.unparse(enum_def)}", _globals, _locals)
+    ModelType = _locals["ModelType"]
+    for data in ModelType:
+        print(data.name)
+        print(f"    type: {data.value['type']}")
+        print(f"    name: {data.value['name']}")
+        print(f"    url: {data.value['url']}")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check", action="store_true")
+
+    args = parser.parse_args()
+
+    # check_cli(args.check)
+    check_supported_models(args.check)
+
+
 if __name__ == "__main__":
     main()
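Note: `check_supported_models` parses the file with `ast` instead of importing it, because importing `text_generation_server.models` would pull in torch and the rest of the server stack. Only the `ModelType` class definition is extracted and re-executed, with `enum` prepended as its sole dependency. A self-contained sketch of the technique against an inline stand-in module (the `source` string is illustrative; `ast.unparse` requires Python 3.9+):

    import ast

    # Inline stand-in for server/text_generation_server/models/__init__.py.
    source = '''
    import torch  # heavy import we must NOT trigger

    class ModelType(enum.Enum):
        MAMBA = {"type": "ssm", "name": "Mamba"}
    '''

    tree = ast.parse(source)

    # Keep only the ClassDef node; every import in the file is skipped.
    enum_def = [
        x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType"
    ][0]

    # Re-execute just the enum definition; `import enum` supplies its one dependency.
    _locals = {}
    exec(f"import enum\n{ast.unparse(enum_def)}", {}, _locals)
    ModelType = _locals["ModelType"]

    for data in ModelType:
        print(data.name, "->", data.value["type"])  # MAMBA -> ssm

The final loop mirrors what `check_supported_models` prints for each member.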