From d7548aef9b9bc871b3520dae85ddf47cf4eb1880 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 11 Apr 2023 16:08:06 +0200
Subject: [PATCH] add llama to readme

---
 README.md                                          | 10 ++++------
 .../text_generation_server/models/__init__.py      | 17 ++++++++++-------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index ad938bdd..bc77fd4c 100644
--- a/README.md
+++ b/README.md
@@ -51,16 +51,14 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
 
-## Officially supported architectures
+## Optimized architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
-- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
-- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Llama](https://github.com/facebookresearch/llama)
 
 Other architectures are supported on a best effort basis using:
 
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index bc802df9..1e06b6dc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -21,12 +21,9 @@ try:
     from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
 
-    FLASH_ATTENTION = (
-        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
-    )
+    FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
-        logger.exception("Could not import Flash Attention models")
+    logger.exception("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 
 __all__ = [
@@ -48,6 +45,12 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashLlama)
+    __all__.append(FlashLlamaSharded)
+
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
+                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
+                          "or install flash attention with `cd server && make install install-flash-attention`"
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -61,7 +64,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
@@ -98,7 +101,7 @@ def get_model(
             if FLASH_ATTENTION:
                 return FlashLlamaSharded(model_id, revision, quantize=quantize)
             raise NotImplementedError(
-                "sharded is not supported for llama when FLASH_ATTENTION=0"
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
            )
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
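
For reference, a minimal standalone sketch of the behaviour the patched `server/text_generation_server/models/__init__.py` ends up with: Flash Attention models are enabled whenever the kernels import cleanly and CUDA is available (no `FLASH_ATTENTION` env var any more), and sharded Llama without the kernels raises the descriptive `FLASH_ATT_ERROR_MESSAGE`. This is only an illustration under stated assumptions: the `flash_attn` import, the stub model classes, the `get_llama` helper, and the placeholder model id are not part of the patch.

```python
# Illustrative sketch only: class names mirror the patch, but the bodies are
# stubs so this file runs without the real text_generation_server package.
import torch


class CausalLM:
    """Stand-in for text_generation_server.models.CausalLM."""

    def __init__(self, model_id, revision=None, quantize=False):
        self.model_id = model_id


class FlashLlama(CausalLM):
    """Stand-in for the Flash Attention Llama model."""


class FlashLlamaSharded(CausalLM):
    """Stand-in for the sharded Flash Attention Llama model."""


try:
    # Assumption: stand-in for the kernel import that fails when the custom
    # Flash Attention CUDA kernels are not installed.
    import flash_attn  # noqa: F401

    # After this patch, availability is purely "kernels import + CUDA present".
    FLASH_ATTENTION = torch.cuda.is_available()
except ImportError:
    FLASH_ATTENTION = False

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)


def get_llama(model_id, revision=None, sharded=False, quantize=False):
    """Dispatch a Llama checkpoint the way the patched get_model branch does."""
    if sharded:
        if FLASH_ATTENTION:
            return FlashLlamaSharded(model_id, revision, quantize=quantize)
        # Sharded Llama has no non-flash fallback, hence the explicit error.
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
    # Unsharded Llama silently falls back to the generic CausalLM path.
    llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
    return llama_cls(model_id, revision, quantize=quantize)


if __name__ == "__main__":
    print("Flash Attention enabled:", FLASH_ATTENTION)
    model = get_llama("some-org/llama-7b", sharded=False)  # placeholder model id
    print("Selected class:", type(model).__name__)
```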