mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
This PR adds paligemma modeling code

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

install the latest changes and run with

```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```

basic example sending various requests

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is ther a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        inputs = f"{prompt}\n"
        json_data = {
            "inputs": inputs,
            "parameters": {
                "max_new_tokens": 30,
                "do_sample": False,
            },
        }
        generated_output = client.text_generation(prompt, max_new_tokens=30, stream=False)
        print([f"{prompt}\n{generated_output}"])
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
245 lines
6.8 KiB
Python
245 lines
6.8 KiB
Python
import os
|
|
import torch
|
|
|
|
from loguru import logger
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
# Explicit opt-out: when the environment sets USE_FLASH_ATTENTION to "false"
# (compared case-insensitively), importing this module fails with ImportError.
if "false" == os.getenv("USE_FLASH_ATTENTION", default="").lower():
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Default capability flags. The cuda/rocm detection code in this module
# resets all three and re-derives them from the importable kernels.
HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = HAS_FLASH_ATTN_V2_ROCM = False
|
|
|
|
if SYSTEM == "xpu":
|
|
import intel_extension_for_pytorch as ipex
|
|
|
|
def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
    causal=True,
):
    """Run variable-length attention through the IPEX (XPU) kernel.

    Args:
        q: Query tensor, forwarded unchanged to the kernel.
        k: Key tensor, forwarded unchanged to the kernel.
        v: Value tensor, forwarded unchanged to the kernel.
        out: Forwarded as the kernel's output argument.
        cu_seqlens: Used for both the query and the key sequence boundaries
            (presumably cumulative sequence lengths — confirm with callers).
        max_s: Used for both the query and the key maximum sequence length.
        softmax_scale: Softmax scaling factor forwarded to the kernel.
        window_size_left: Must be -1; windowed attention is rejected here.
        causal: Forwarded in the slot that was previously hard-coded to
            ``True``. Accepted so this backend has the same signature as the
            CUDA flash-attention v2 variant; the default keeps the old
            behaviour.

    Returns:
        Whatever ``ipex.llm.functional.varlen_attention`` returns.

    Raises:
        ValueError: If ``window_size_left`` is not -1 (either an invalid
            value or a window request this backend cannot honour).
    """
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

    if window_size_left != -1:
        raise ValueError(
            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
        )

    # NOTE(review): the trailing positional arguments are assumed to follow
    # the flash-attention varlen convention (dropout_p, softmax_scale,
    # zero_tensors, is_causal, return_softmax, generator) — confirm against
    # the installed IPEX version.
    return ipex.llm.functional.varlen_attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        None,
    )
|
|
|
|
|
|
# Detect which flash-attention kernels can be used on CUDA / ROCm systems and
# publish the result through the HAS_FLASH_ATTN* module flags.
if SYSTEM in {"cuda", "rocm"}:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    # Capability of the current device, used to gate each kernel generation.
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    # Start from "nothing available"; the flags are only raised once the
    # corresponding import and hardware checks have succeeded.
    HAS_FLASH_ATTN = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False
    try:
        # First choice: flash-attention v2.
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
    except ImportError as e:
        # Fallback: flash-attention v1. Every failure below is chained to the
        # v2 error so the full reason is visible in the traceback.
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif SYSTEM == "rocm":
            for idx in range(torch.cuda.device_count()):
                # Query the name once per device instead of once per check.
                device_name = torch.cuda.get_device_name(idx)
                if "MI210" not in device_name and "MI250" not in device_name:
                    raise ImportError(
                        f"AMD GPU {device_name} does not support flash-attention"
                    ) from e

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True
|
|
|
|
|
|
if HAS_FLASH_ATTN_V2_CUDA:
|
|
|
|
def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
    causal=True,
):
    """Run variable-length attention through the flash-attention v2 CUDA kernel.

    Args:
        q: Query tensor, forwarded unchanged to the kernel.
        k: Key tensor, forwarded unchanged to the kernel.
        v: Value tensor, forwarded unchanged to the kernel.
        out: Forwarded as the kernel's output argument.
        cu_seqlens: Used for both the query and the key sequence boundaries
            (presumably cumulative sequence lengths — confirm with callers).
        max_s: Used for both the query and the key maximum sequence length.
        softmax_scale: Softmax scaling factor forwarded to the kernel.
        window_size_left: -1 for no left window, or a positive window size.
        causal: Whether to apply causal masking.

    Returns:
        Whatever ``flash_attn_2_cuda.varlen_fwd`` returns.

    Raises:
        ValueError: If ``window_size_left`` is neither -1 nor positive.
    """
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

    # flash-attention defines the window as (left, right) offsets around the
    # query position, so a right window of 0 hides every later key, which is
    # causal masking regardless of the `causal` flag. Only pin it to 0 when
    # causal masking is requested; -1 leaves the right side unbounded.
    window_size_right = 0 if causal else -1

    # NOTE(review): the three None arguments and the trailing False/None are
    # optional kernel inputs whose meaning depends on the pinned flash-attn
    # version — confirm against that version's `varlen_fwd` signature.
    return flash_attn_2_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        None,
        None,
        None,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        causal,
        window_size_left,
        window_size_right,
        False,
        None,
    )
|
|
|
|
elif HAS_FLASH_ATTN_V2_ROCM:
|
|
|
|
def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
    causal=True,
):
    """Run variable-length attention through the ROCm flash-attention v2 kernel.

    Args:
        q: Query tensor, forwarded unchanged to the kernel.
        k: Key tensor, forwarded unchanged to the kernel.
        v: Value tensor, forwarded unchanged to the kernel.
        out: Forwarded as the kernel's output argument.
        cu_seqlens: Used for both the query and the key sequence boundaries
            (presumably cumulative sequence lengths — confirm with callers).
        max_s: Used for both the query and the key maximum sequence length.
        softmax_scale: Softmax scaling factor forwarded to the kernel.
        window_size_left: Must be -1; windowed attention is rejected here.
        causal: Forwarded in the slot that was previously hard-coded to
            ``True``. Accepted so this backend has the same signature as the
            CUDA flash-attention v2 variant; the default keeps the old
            behaviour.

    Returns:
        Whatever ``flash_attn_2_cuda.varlen_fwd`` returns.

    Raises:
        ValueError: If ``window_size_left`` is not -1 (either an invalid
            value or a window request this backend cannot honour).
    """
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    if window_size_left != -1:
        raise ValueError(
            f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
        )

    # RoCm flash API does not take the window_size_left and window_size_right arguments.
    # NOTE(review): `causal` is assumed to sit in the is_causal slot
    # (after zero_tensors, before return_softmax) — confirm against the
    # installed ROCm flash-attention build.
    return flash_attn_2_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        None,
    )
|
|
|
|
elif HAS_FLASH_ATTN:
|
|
|
|
def _match_query_heads(kv, num_heads):
    """Return ``kv`` with its head dimension (dim 1) expanded to ``num_heads``."""
    kv_heads = kv.shape[1]
    if kv_heads == num_heads:
        return kv
    if kv_heads == 1:
        # MQA expand
        return kv.expand(-1, num_heads, -1)
    # Grouped attention reshape: repeat each kv head for its query group.
    tokens, _, head_dim = kv.shape
    return (
        kv.unsqueeze(2)
        .expand(-1, -1, num_heads // kv_heads, -1)
        .reshape(tokens, -1, head_dim)
    )


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
    causal=True,
):
    """Run variable-length attention through the flash-attention v1 kernel.

    Args:
        q: Query tensor; its dim 1 is treated as the number of heads.
        k: Key tensor; expanded along dim 1 when it has fewer heads than ``q``.
        v: Value tensor; expanded along dim 1 when it has fewer heads than ``q``.
        out: Forwarded as the kernel's output argument.
        cu_seqlens: Used for both the query and the key sequence boundaries
            (presumably cumulative sequence lengths — confirm with callers).
        max_s: Used for both the query and the key maximum sequence length.
        softmax_scale: Softmax scaling factor forwarded to the kernel.
        window_size_left: Must be -1; windowed attention needs flash attn v2.
        causal: Forwarded in the slot that was previously hard-coded to
            ``True``. Accepted so this backend has the same signature as the
            CUDA flash-attention v2 variant; the default keeps the old
            behaviour.

    Returns:
        Whatever ``flash_attn_cuda.fwd`` returns.

    Raises:
        NotImplementedError: If ``window_size_left`` is not -1.
    """
    if window_size_left != -1:
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        )

    # Flash attention v1 requires q, k and v to have the same number of heads
    num_heads = q.shape[1]
    k = _match_query_heads(k, num_heads)
    v = _match_query_heads(v, num_heads)

    # NOTE(review): `causal` is assumed to sit in the is_causal slot
    # (after zero_tensors, before return_softmax) — confirm against the
    # installed flash-attention v1 build.
    return flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        0,
        None,
    )
|
|
|
|
else:
|
|
raise NotImplementedError("flash attention is not installed")
|