Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
add llama to readme

commit d7548aef9b · parent c2beaa279e
README.md (10 changed lines)
```diff
@@ -51,16 +51,14 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
 
-## Officially supported architectures
+## Optimized architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
-- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
-- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Llama](https://github.com/facebookresearch/llama)
 
 Other architectures are supported on a best effort basis using:
 
```
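With Llama added to the optimized list, any model in this table is queried the same way once the server is running. A minimal sketch using the companion `text_generation` Python client; the endpoint URL and prompt are placeholders, and it assumes the client package is installed (`pip install text-generation`) and a server is already serving one of the listed models:

```python
from text_generation import Client

# Placeholder endpoint: assumes a text-generation-inference server is listening locally.
client = Client("http://127.0.0.1:8080")

# Generate a short completion; `generated_text` holds the decoded output.
response = client.generate("What is Deep Learning?", max_new_tokens=20)
print(response.generated_text)
```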
server/text_generation_server/models/__init__.py

```diff
@@ -21,12 +21,9 @@ try:
     from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
 
-    FLASH_ATTENTION = (
-        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
-    )
+    FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
-        logger.exception("Could not import Flash Attention models")
+    logger.exception("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 
 __all__ = [
```
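Read without the diff markers, this hunk simplifies how the module decides whether flash-attention models are usable: the flag now depends only on CUDA being available, with an `ImportError` falling back to `False`, instead of also requiring a `FLASH_ATTENTION` environment variable. A standalone sketch of that guarded-import pattern, using a hypothetical optional module name (`optional_flash_models`) in place of the real imports:

```python
import logging

import torch

logger = logging.getLogger(__name__)

try:
    # Hypothetical optional dependency standing in for the flash-attention model classes;
    # importing it fails unless the CUDA kernels are built and installed.
    from optional_flash_models import FlashModel  # noqa: F401

    # After this commit, availability is derived from CUDA alone,
    # not from a FLASH_ATTENTION environment variable.
    FLASH_ATTENTION = torch.cuda.is_available()
except ImportError:
    logger.exception("Could not import Flash Attention enabled models")
    FLASH_ATTENTION = False

print(f"FLASH_ATTENTION={FLASH_ATTENTION}")
```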
```diff
@@ -48,6 +45,12 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashLlama)
+    __all__.append(FlashLlamaSharded)
 
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
+    "or install flash attention with `cd server && make install install-flash-attention`"
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
```
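The new `FLASH_ATT_ERROR_MESSAGE` is a format template: the `{}` placeholder is filled with a description of the feature that needs the missing kernels, as the last hunk below does for sharded Llama. A small usage sketch; the `load_sharded_llama` helper is hypothetical and added only to show the call site:

```python
# Template copied from the diff above.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
    "or install flash attention with `cd server && make install install-flash-attention`"

FLASH_ATTENTION = False  # assume the kernels are unavailable for this example


def load_sharded_llama():
    # Hypothetical helper: mirrors how get_model reports the missing dependency.
    if not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))


try:
    load_sharded_llama()
except NotImplementedError as err:
    print(err)  # "Sharded Llama requires Flash Attention CUDA kernels to be installed. ..."
```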
```diff
@@ -61,7 +64,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
```
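For orientation, `get_model` is the factory the server calls to turn its launch arguments into a concrete model instance. A hedged call sketch matching the signature above; it assumes the `text_generation_server` package is installed and uses a model id from the README list purely as an example:

```python
from text_generation_server.models import get_model

# Example arguments only; loading a real checkpoint requires the weights to be available.
model = get_model(
    model_id="bigscience/bloom",
    revision=None,
    sharded=False,
    quantize=False,
)
```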
```diff
@@ -98,7 +101,7 @@ def get_model(
             if FLASH_ATTENTION:
                 return FlashLlamaSharded(model_id, revision, quantize=quantize)
             raise NotImplementedError(
-                "sharded is not supported for llama when FLASH_ATTENTION=0"
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
             )
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
```
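Stripped of the diff markers, the Llama branch now dispatches on two flags: a sharded checkpoint requires the flash-attention implementation (and otherwise raises the formatted error above), while the single-shard case falls back to the generic `CausalLM` when the kernels are missing. A condensed sketch of that decision; `pick_llama_implementation` is a hypothetical helper used only to illustrate the branch, and it returns class names as strings so it runs without the server installed:

```python
def pick_llama_implementation(sharded: bool, flash_attention: bool) -> str:
    """Condensed sketch of the Llama dispatch in get_model (returns class names for clarity)."""
    if sharded:
        if flash_attention:
            return "FlashLlamaSharded"  # tensor-parallel path, needs flash-attention kernels
        raise NotImplementedError(
            "Sharded Llama requires Flash Attention CUDA kernels to be installed."
        )
    # Single shard: use the optimized class when available, otherwise the generic one.
    return "FlashLlama" if flash_attention else "CausalLM"


print(pick_llama_implementation(sharded=True, flash_attention=True))    # FlashLlamaSharded
print(pick_llama_implementation(sharded=False, flash_attention=False))  # CausalLM
```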