add llama to readme

OlivierDehaene 2023-04-11 16:08:06 +02:00
parent c2beaa279e
commit d7548aef9b
2 changed files with 14 additions and 13 deletions

README.md

@@ -51,16 +51,14 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-## Officially supported architectures
+## Optimized architectures
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Llama](https://github.com/facebookresearch/llama)
 Other architectures are supported on a best effort basis using:
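For readers who want to try the newly listed Llama support, the snippet below is an illustrative client call, assuming a text-generation-inference server is already running locally on port 8080 with a Llama checkpoint loaded; the address, port, and exact parameters are assumptions, not taken from this commit.

```python
import requests

# Assumed local endpoint; the server address and port are hypothetical.
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    timeout=60,
)
response.raise_for_status()
# The /generate route returns a JSON body containing the generated text.
print(response.json()["generated_text"])
```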

server/text_generation_server/models/__init__.py

@@ -21,12 +21,9 @@ try:
     from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
-    FLASH_ATTENTION = (
-        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
-    )
+    FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
-        logger.exception("Could not import Flash Attention models")
+    logger.exception("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 __all__ = [
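The hunk above removes the FLASH_ATTENTION=1 environment-variable gate: the flash models are enabled whenever their import succeeds and CUDA is available, and the flag falls back to False otherwise. Below is a minimal sketch of that optional-import pattern, using a hypothetical module name in place of the real flash model imports:

```python
import logging

import torch

logger = logging.getLogger(__name__)

try:
    # Hypothetical stand-in for the custom flash attention models; the import
    # fails on machines where the CUDA kernels were never built.
    from flash_models import FlashCausalLM  # noqa: F401

    FLASH_ATTENTION = torch.cuda.is_available()
except ImportError:
    logger.exception("Could not import Flash Attention enabled models")
    FLASH_ATTENTION = False

# Downstream code branches on the flag instead of re-reading an env var.
print("flash attention enabled:", FLASH_ATTENTION)
```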
@@ -48,6 +45,12 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashLlama)
+    __all__.append(FlashLlamaSharded)
+
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
+                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
+                          "or install flash attention with `cd server && make install install-flash-attention`"
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -61,7 +64,7 @@ torch.set_grad_enabled(False)
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
@@ -98,7 +101,7 @@ def get_model(
            if FLASH_ATTENTION:
                return FlashLlamaSharded(model_id, revision, quantize=quantize)
            raise NotImplementedError(
-                "sharded is not supported for llama when FLASH_ATTENTION=0"
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
            )
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
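Putting the last two hunks together, the llama branch of get_model now reads roughly as sketched below. The model_id handling and the non-sharded return are not shown in the hunk and are filled in here as assumptions, and the import only works when the package and its flash attention kernels are installed:

```python
from typing import Optional

# These names come from the diffed module; importing them assumes the package
# (with flash attention kernels) is installed.
from text_generation_server.models import (
    FLASH_ATTENTION,
    FLASH_ATT_ERROR_MESSAGE,
    CausalLM,
    FlashLlama,
    FlashLlamaSharded,
)


def get_llama(model_id: str, revision: Optional[str], sharded: bool, quantize: bool):
    # Simplified reading of the llama dispatch after this commit; the exact
    # model_id check inside get_model is an assumption.
    if sharded:
        if FLASH_ATTENTION:
            return FlashLlamaSharded(model_id, revision, quantize=quantize)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
    # Without sharding, fall back to the generic implementation when the
    # flash attention kernels are unavailable.
    llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
    return llama_cls(model_id, revision, quantize=quantize)
```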