Creating doc automatically for supported models.

This commit is contained in:
Nicolas Patry 2024-05-21 15:48:01 +02:00
parent fc0eaffc81
commit 1373c185c3
2 changed files with 187 additions and 29 deletions

View File

@ -1,4 +1,5 @@
import torch
import enum
import os
from loguru import logger
@ -116,6 +117,138 @@ if MAMBA_AVAILABLE:
__all__.append(Mamba)
class ModelType(enum.Enum):
    """Catalog of model architectures supported by the server.

    Each member's value is a dict with:
      - "type": the ``model_type`` string found in a model's HF config.json,
        used for dispatch in ``get_model``,
      - "name": human-readable display name for the auto-generated docs,
      - "url": a canonical Hugging Face repository showcasing the architecture.
    """

    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    # Member name kept as DRBX for backward compatibility with existing
    # references, but the HF config.json model_type for DBRX is "dbrx"
    # (the code this replaces dispatched on "dbrx"); "drbx" here would
    # make the dispatch never match.
    DRBX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        # Was a copy-paste of the StarCoder 2 url; point at a Qwen repo.
        "url": "https://huggingface.co/Qwen/Qwen1.5-7B-Chat",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
    }
# Mirror every ModelType member as a module-level constant bound to its
# "type" string (e.g. MAMBA = "ssm"), so the dispatch code below can write
# `model_type == MAMBA` instead of comparing against raw string literals.
for data in ModelType:
    globals()[data.name] = data.value["type"]
def get_model(
model_id: str,
revision: Optional[str],
@ -267,7 +400,7 @@ def get_model(
else:
logger.info(f"Unknown quantization method {method}")
if model_type == "ssm":
if model_type == MAMBA:
return Mamba(
model_id,
revision,
@ -288,8 +421,8 @@ def get_model(
)
if (
model_type == "gpt_bigcode"
or model_type == "gpt2"
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION:
@ -315,7 +448,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "bloom":
if model_type == BLOOM:
return BLOOMSharded(
model_id,
revision,
@ -324,7 +457,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "mpt":
elif model_type == MPT:
return MPTSharded(
model_id,
revision,
@ -333,7 +466,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt2":
elif model_type == GPT2:
if FLASH_ATTENTION:
return FlashGPT2(
model_id,
@ -354,7 +487,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox":
elif model_type == GPT_NEOX:
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
@ -383,7 +516,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "phi":
elif model_type == PHI:
if FLASH_ATTENTION:
return FlashPhi(
model_id,
@ -418,7 +551,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION:
return FlashLlama(
model_id,
@ -439,7 +572,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "gemma":
if model_type == GEMMA:
if FLASH_ATTENTION:
return FlashGemma(
model_id,
@ -461,7 +594,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "cohere":
if model_type == COHERE:
if FLASH_ATTENTION:
return FlashCohere(
model_id,
@ -483,7 +616,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "dbrx":
if model_type == DRBX:
if FLASH_ATTENTION:
return FlashDbrx(
model_id,
@ -505,7 +638,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
if sharded:
if FLASH_ATTENTION:
if config_dict.get("alibi", False):
@ -539,7 +672,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "mistral":
if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@ -566,7 +699,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "mixtral":
if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@ -593,7 +726,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "starcoder2":
if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@ -621,7 +754,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "qwen2":
if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@ -647,7 +780,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "opt":
if model_type == OPT:
return OPTSharded(
model_id,
revision,
@ -657,7 +790,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "t5":
if model_type == T5:
return T5Sharded(
model_id,
revision,
@ -666,7 +799,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "idefics":
if model_type == IDEFICS:
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
@ -678,7 +811,7 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "idefics2":
if model_type == IDEFICS2:
if FLASH_ATTENTION:
return Idefics2(
model_id,
@ -703,7 +836,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
return LlavaNext(
model_id,

View File

@ -1,13 +1,9 @@
import subprocess
import argparse
import ast
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
args = parser.parse_args()
def check_cli(check: bool):
output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
"utf-8"
)
@ -41,7 +37,7 @@ def main():
block = []
filename = "docs/source/basic_tutorials/launcher.md"
if args.check:
if check:
with open(filename, "r") as f:
doc = f.read()
if doc != final_doc:
@ -60,5 +56,34 @@ def main():
f.write(final_doc)
def check_supported_models(check: bool):
    """Print the model metadata declared in the server's ``ModelType`` enum.

    Parses the server package's ``models/__init__.py`` with ``ast``, extracts
    the ``ModelType`` class definition, evaluates it in isolation (so the
    heavy server imports are never executed), and prints each member's
    type/name/url fields.

    NOTE(review): the ``check`` flag is currently unused here — presumably a
    docs-diff check (as in ``check_cli``) is still to be wired up; confirm.
    """
    source_path = "server/text_generation_server/models/__init__.py"
    with open(source_path, "r") as handle:
        module_tree = ast.parse(handle.read())

    enum_node = next(
        node
        for node in module_tree.body
        if isinstance(node, ast.ClassDef) and node.name == "ModelType"
    )

    namespace = {}
    # exec on our own checked-in source file — not untrusted input.
    exec(f"import enum\n{ast.unparse(enum_node)}", {}, namespace)

    for member in namespace["ModelType"]:
        print(member.name)
        print(f" type: {member.value['type']}")
        print(f" name: {member.value['name']}")
        print(f" url: {member.value['url']}")
def main():
    """CLI entry point for the docs generator/checker.

    With ``--check``, verification mode is requested instead of rewriting
    the generated documentation files.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--check", action="store_true")
    parsed = arg_parser.parse_args()
    # check_cli(parsed.check)  # NOTE(review): launcher-docs check disabled upstream
    check_supported_models(parsed.check)


if __name__ == "__main__":
    main()