Mirror of https://github.com/huggingface/text-generation-inference.git

commit d7548aef9b (parent c2beaa279e)

    add llama to readme

 README.md | 10
README.md
@@ -51,16 +51,14 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
 
-## Officially supported architectures
+## Optimized architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
-- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
-- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Llama](https://github.com/facebookresearch/llama)
 
 Other architectures are supported on a best effort basis using:
 
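For context only, a hedged usage sketch (not part of this commit) of querying a running text-generation-inference server that is serving one of the models listed above, using the project's `text_generation` Python client. The endpoint URL, prompt, and generation parameters below are assumptions.

# Hedged sketch: query a server already serving one of the listed models.
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # assumed local server address
response = client.generate("What is Deep Learning?", max_new_tokens=20)
print(response.generated_text)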
server/text_generation_server/models/__init__.py
@@ -21,12 +21,9 @@ try:
     from text_generation_server.models.flash_santacoder import FlashSantacoder
+    from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
 
-    FLASH_ATTENTION = (
-        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
-    )
+    FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
-        logger.exception("Could not import Flash Attention models")
+    logger.exception("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 
 __all__ = [
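The hunk above changes how Flash Attention support is detected: it used to require both a CUDA device and the opt-in environment variable FLASH_ATTENTION=1, and now it is enabled whenever CUDA is available and the imports succeed. A minimal sketch comparing the two rules, with the logic copied from the diff; the variable names `flash_attention_old` and `flash_attention_new` are illustrative.

import os

import torch

# Old rule: opt-in via the FLASH_ATTENTION environment variable, in addition
# to requiring a CUDA device.
flash_attention_old = (
    torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
)

# New rule: enabled whenever CUDA is available (in the real module this also
# requires the flash attention imports in the try block to succeed).
flash_attention_new = torch.cuda.is_available()

print(f"old rule: {flash_attention_old}, new rule: {flash_attention_new}")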
@@ -48,6 +45,12 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashLlama)
+    __all__.append(FlashLlamaSharded)
+
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
+                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
+                          "or install flash attention with `cd server && make install install-flash-attention`"
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
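The new FLASH_ATT_ERROR_MESSAGE template is filled with the name of the model variant that needs the CUDA kernels, as the later get_model hunk does for sharded Llama. A small sketch of how the template expands; the constant is copied from the hunk above.

# Template copied from the hunk above; "{}" is replaced with the model name.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
                          "or install flash attention with `cd server && make install install-flash-attention`"

# Mirrors the usage in get_model() further down in this diff.
print(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))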
@@ -61,7 +64,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
@@ -98,7 +101,7 @@ def get_model(
             if FLASH_ATTENTION:
                 return FlashLlamaSharded(model_id, revision, quantize=quantize)
             raise NotImplementedError(
-                "sharded is not supported for llama when FLASH_ATTENTION=0"
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
             )
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
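Read together, the get_model hunks route Llama to the Flash Attention implementations when the kernels are available and fall back to the generic CausalLM path otherwise, while sharded Llama now fails with the shared FLASH_ATT_ERROR_MESSAGE instead of an ad-hoc string. A minimal, self-contained sketch of that dispatch; every class and constant defined here is a placeholder standing in for the real server module, and the model id in the final line is hypothetical.

from typing import Optional

# Stand-ins so this sketch runs on its own; in the real module these names
# come from text_generation_server.models as shown in the hunks above.
FLASH_ATTENTION = False
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed."


class CausalLM:  # placeholder for the real fallback implementation
    def __init__(self, model_id: str, revision: Optional[str], quantize: bool):
        self.model_id = model_id


FlashLlama = FlashLlamaSharded = CausalLM  # placeholders for the flash classes


def get_llama_model(model_id: str, revision: Optional[str], sharded: bool, quantize: bool):
    """Mirrors the Llama branch of get_model() after this commit."""
    if sharded:
        if FLASH_ATTENTION:
            return FlashLlamaSharded(model_id, revision, quantize=quantize)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
    # Prefer the Flash Attention implementation, fall back to generic CausalLM.
    llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
    return llama_cls(model_id, revision, quantize=quantize)


# "dummy/llama" is a hypothetical model id used only to exercise the sketch.
print(type(get_llama_model("dummy/llama", None, sharded=False, quantize=False)).__name__)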