From 1373c185c3ea079de8abba86e9bfdd6a30b1df3a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 21 May 2024 15:48:01 +0200
Subject: [PATCH] Creating doc automatically for supported models.

---
 .../text_generation_server/models/__init__.py | 177 +++++++++++++++---
 update_doc.py                                 |  39 +++-
 2 files changed, 187 insertions(+), 29 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9e5676f5..eb2c1158 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,4 +1,5 @@
 import torch
+import enum
 import os
 
 from loguru import logger
@@ -116,6 +117,138 @@ if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
 
+class ModelType(enum.Enum):
+    MAMBA = {
+        "type": "ssm",
+        "name": "Mamba",
+        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
+    }
+    GALACTICA = {
+        "type": "galactica",
+        "name": "Galactica",
+        "url": "https://huggingface.co/facebook/galactica-120b",
+    }
+    SANTACODER = {
+        "type": "santacoder",
+        "name": "SantaCoder",
+        "url": "https://huggingface.co/bigcode/santacoder",
+    }
+    BLOOM = {
+        "type": "bloom",
+        "name": "Bloom",
+        "url": "https://huggingface.co/bigscience/bloom-560m",
+    }
+    MPT = {
+        "type": "mpt",
+        "name": "Mpt",
+        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
+    }
+    GPT2 = {
+        "type": "gpt2",
+        "name": "Gpt2",
+        "url": "https://huggingface.co/openai-community/gpt2",
+    }
+    GPT_NEOX = {
+        "type": "gpt_neox",
+        "name": "Gpt Neox",
+        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
+    }
+    GPT_BIGCODE = {
+        "type": "gpt_bigcode",
+        "name": "Gpt Bigcode",
+        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
+    }
+    PHI = {
+        "type": "phi",
+        "name": "Phi",
+        "url": "https://huggingface.co/microsoft/phi-1_5",
+    }
+    PHI3 = {
+        "type": "phi3",
+        "name": "Phi 3",
+        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+    }
+    LLAMA = {
+        "type": "llama",
+        "name": "Llama",
+        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
+    }
+    BAICHUAN = {
+        "type": "baichuan",
+        "name": "Baichuan",
+        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+    }
+    GEMMA = {
+        "type": "gemma",
+        "name": "Gemma",
+        "url": "https://huggingface.co/google/gemma-7b",
+    }
+    COHERE = {
+        "type": "cohere",
+        "name": "Cohere",
+        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
+    }
+    DBRX = {
+        "type": "dbrx",
+        "name": "Dbrx",
+        "url": "https://huggingface.co/databricks/dbrx-instruct",
+    }
+    FALCON = {
+        "type": "falcon",
+        "name": "Falcon",
+        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
+    }
+    MISTRAL = {
+        "type": "mistral",
+        "name": "Mistral",
+        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
+    }
+    MIXTRAL = {
+        "type": "mixtral",
+        "name": "Mixtral",
+        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
+    }
+    STARCODER2 = {
+        "type": "starcoder2",
+        "name": "StarCoder 2",
+        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+    }
+    QWEN2 = {
+        "type": "qwen2",
+        "name": "Qwen 2",
+        "url": "https://huggingface.co/Qwen/Qwen1.5-7B-Chat",
+    }
+    OPT = {
+        "type": "opt",
+        "name": "Opt",
+        "url": "https://huggingface.co/facebook/opt-6.7b",
+    }
+    T5 = {
+        "type": "t5",
+        "name": "T5",
+        "url": "https://huggingface.co/google/flan-t5-xxl",
+    }
+    IDEFICS = {
+        "type": "idefics",
+        "name": "Idefics",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
+    }
+    IDEFICS2 = {
+        "type": "idefics2",
+        "name": "Idefics 2",
+        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
+    }
+    LLAVA_NEXT = {
+        "type": "llava_next",
+        "name": "Llava Next (1.6)",
+        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+    }
+
+
+for data in ModelType:
+    globals()[data.name] = data.value["type"]
+
+
 def get_model(
     model_id: str,
     revision: Optional[str],
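The for-loop at the end of the hunk above is what makes the rest of this diff work: it injects each enum member's "type" string into module globals, so the dispatch code in get_model can compare model_type against bare names like MAMBA or BLOOM while behaving identically to the old string literals. A minimal standalone sketch of the mechanism, not part of the patch:

    import enum

    class ModelType(enum.Enum):
        LLAMA = {
            "type": "llama",
            "name": "Llama",
            "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
        }

    for data in ModelType:
        # Re-export each member's "type" string as a module-level constant.
        globals()[data.name] = data.value["type"]

    assert LLAMA == "llama"  # injected above; compares equal to the raw string

Because the constants are plain strings, every existing comparison in get_model keeps working unchanged.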
"name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + } + + +for data in ModelType: + globals()[data.name] = data.value["type"] + + def get_model( model_id: str, revision: Optional[str], @@ -267,7 +400,7 @@ def get_model( else: logger.info(f"Unknown quantization method {method}") - if model_type == "ssm": + if model_type == MAMBA: return Mamba( model_id, revision, @@ -288,8 +421,8 @@ def get_model( ) if ( - model_type == "gpt_bigcode" - or model_type == "gpt2" + model_type == GPT_BIGCODE + or model_type == GPT2 and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: @@ -315,7 +448,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "bloom": + if model_type == BLOOM: return BLOOMSharded( model_id, revision, @@ -324,7 +457,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "mpt": + elif model_type == MPT: return MPTSharded( model_id, revision, @@ -333,7 +466,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "gpt2": + elif model_type == GPT2: if FLASH_ATTENTION: return FlashGPT2( model_id, @@ -354,7 +487,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "gpt_neox": + elif model_type == GPT_NEOX: if FLASH_ATTENTION: return FlashNeoXSharded( model_id, @@ -383,7 +516,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "phi": + elif model_type == PHI: if FLASH_ATTENTION: return FlashPhi( model_id, @@ -418,7 +551,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3": + elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: return FlashLlama( model_id, @@ -439,7 +572,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "gemma": + if model_type == GEMMA: if FLASH_ATTENTION: return FlashGemma( model_id, @@ -461,7 +594,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "cohere": + if model_type == COHERE: if FLASH_ATTENTION: return FlashCohere( model_id, @@ -483,7 +616,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "dbrx": + if model_type == DRBX: if FLASH_ATTENTION: return FlashDbrx( model_id, @@ -505,7 +638,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]: + if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: if sharded: if FLASH_ATTENTION: if config_dict.get("alibi", False): @@ -539,7 +672,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "mistral": + if model_type == MISTRAL: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -566,7 +699,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "mixtral": + if model_type == MIXTRAL: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -593,7 +726,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "starcoder2": + if model_type == STARCODER2: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and 
@@ -315,7 +448,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "bloom":
+    if model_type == BLOOM:
         return BLOOMSharded(
             model_id,
             revision,
@@ -324,7 +457,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "mpt":
+    elif model_type == MPT:
         return MPTSharded(
             model_id,
             revision,
@@ -333,7 +466,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt2":
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             return FlashGPT2(
                 model_id,
@@ -354,7 +487,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "gpt_neox":
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
@@ -383,7 +516,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "phi":
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashPhi(
                 model_id,
@@ -418,7 +551,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
@@ -439,7 +572,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "gemma":
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashGemma(
                 model_id,
@@ -461,7 +594,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "cohere":
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCohere(
                 model_id,
@@ -483,7 +616,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "dbrx":
+    if model_type == DBRX:
         if FLASH_ATTENTION:
             return FlashDbrx(
                 model_id,
@@ -505,7 +638,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
@@ -539,7 +672,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mistral":
+    if model_type == MISTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -566,7 +699,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mixtral":
+    if model_type == MIXTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -593,7 +726,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
        )
 
-    if model_type == "starcoder2":
+    if model_type == STARCODER2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -621,7 +754,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "qwen2":
+    if model_type == QWEN2:
         sliding_window = config_dict.get("sliding_window", -1)
         if (
             ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
@@ -647,7 +780,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "opt":
+    if model_type == OPT:
         return OPTSharded(
             model_id,
             revision,
@@ -657,7 +790,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "t5":
+    if model_type == T5:
         return T5Sharded(
             model_id,
             revision,
@@ -666,7 +799,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "idefics":
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -678,7 +811,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "idefics2":
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return Idefics2(
                 model_id,
@@ -703,7 +836,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
-    if model_type == "llava_next":
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return LlavaNext(
                 model_id,
diff --git a/update_doc.py b/update_doc.py
index 6127418c..1ab4488b 100644
--- a/update_doc.py
+++ b/update_doc.py
@@ -1,13 +1,9 @@
 import subprocess
 import argparse
+import ast
 
 
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--check", action="store_true")
-
-    args = parser.parse_args()
-
+def check_cli(check: bool):
     output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
         "utf-8"
     )
@@ -41,7 +37,7 @@
             block = []
 
     filename = "docs/source/basic_tutorials/launcher.md"
-    if args.check:
+    if check:
         with open(filename, "r") as f:
             doc = f.read()
         if doc != final_doc:
@@ -60,5 +56,34 @@
             f.write(final_doc)
 
 
+def check_supported_models(check: bool):
+    filename = "server/text_generation_server/models/__init__.py"
+    with open(filename, "r") as f:
+        tree = ast.parse(f.read())
+
+    enum_def = [
+        x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType"
+    ][0]
+    _locals = {}
+    _globals = {}
+    exec(f"import enum\n{ast.unparse(enum_def)}", _globals, _locals)
+    ModelType = _locals["ModelType"]
+    for data in ModelType:
+        print(data.name)
+        print(f"    type: {data.value['type']}")
+        print(f"    name: {data.value['name']}")
+        print(f"    url: {data.value['url']}")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check", action="store_true")
+
+    args = parser.parse_args()
+
+    check_cli(args.check)
+    check_supported_models(args.check)
+
+
 if __name__ == "__main__":
     main()
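As it stands, check_supported_models only prints the extracted entries and ignores its check argument, so the supported-models page is not yet generated or verified. A natural follow-up, mirroring check_cli, would render the enum into markdown and then either compare it against or overwrite a docs file. A minimal sketch of that rendering step (the render_supported_models helper, the output path, and the bullet-list format are assumptions for illustration, not part of this patch):

    def render_supported_models(model_type_enum) -> str:
        # Build a markdown bullet list with one link per supported architecture.
        lines = ["# Supported Models", ""]
        for data in model_type_enum:
            lines.append(f"- [{data.value['name']}]({data.value['url']})")
        return "\n".join(lines) + "\n"

    # Hypothetical use inside check_supported_models, following check_cli's pattern:
    #     final_doc = render_supported_models(ModelType)
    #     filename = "docs/source/supported_models.md"  # assumed path
    #     if check: compare final_doc to the file and error on mismatch
    #     else: overwrite the file with final_doc

The exec-over-ast.unparse trick exists so update_doc.py can read the enum without importing text_generation_server itself, and therefore without pulling in torch and the rest of the model stack.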