Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fixing rocm.
commit 908973ee0e
parent 8aece3bd68
@@ -126,40 +126,34 @@ if ENGINE != "triton":
         import flash_attn_2_cuda

         logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-    except ImportError:
-        try:
-            import flash_attn_cuda
-
-            ENGINE = "v1"
-            logger.info("ROCm: using Flash Attention 1")
-        except ImportError as e:
-            if major >= 8:
-                architecture_suffix = f"-{SYSTEM}"
-                raise ImportError(
-                    "Flash Attention V2 is not installed.\n"
-                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                    f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-                )
-            elif is_sm75:
-                raise ImportError(
-                    "Flash Attention is not installed.\n"
-                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                    "or install flash attention with `cd server && make install install-flash-attention`"
-                ) from e
-            else:
-
-                for idx in range(torch.cuda.device_count()):
-                    name = torch.cuda.get_device_name(idx)
-                    if "MI210" not in name and "MI250" not in name:
-                        raise ImportError(
-                            f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                        )
-                raise ImportError(
-                    f"AMD GPU with ROCm capability {major} {minor} is not supported"
-                ) from e
+    except ImportError as e:
+        if major >= 8:
+            architecture_suffix = f"-{SYSTEM}"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        elif is_sm75:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+        else:
+
+            for idx in range(torch.cuda.device_count()):
+                name = torch.cuda.get_device_name(idx)
+                if "MI210" not in name and "MI250" not in name:
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+            raise ImportError(
+                f"AMD GPU with ROCm capability {major} {minor} is not supported"
+            ) from e


-SUPPORTS_WINDOWING = ENGINE != "v1"
+SUPPORTS_WINDOWING = False
 if ENGINE == "ck":

     def attention(
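Note on this hunk: on ROCm, the non-Triton path now requires the Composable Kernel build of Flash Attention 2; the silent fallback to flash_attn_cuda (ENGINE = "v1") is removed and SUPPORTS_WINDOWING is pinned to False. Below is a minimal, hedged sketch of that selection flow. select_rocm_engine() and its arguments are invented here for illustration; the MI210/MI250 check, the install hint, and the "no v1 fallback" behaviour come from the hunk itself.

# Simplified sketch, not the module itself: pick the ROCm attention engine or
# fail loudly with install guidance instead of silently falling back to v1.
import logging

logger = logging.getLogger(__name__)


def select_rocm_engine(device_names, system="rocm"):
    """Return the attention engine name, or raise ImportError with guidance."""
    try:
        import flash_attn_2_cuda  # noqa: F401  (Composable Kernel build of Flash Attention 2)

        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
        return "ck"
    except ImportError as e:
        # No v1 fallback any more: unsupported GPUs get a clear message,
        # supported ones are told how to install the CK kernels.
        for name in device_names:
            if "MI210" not in name and "MI250" not in name:
                raise ImportError(
                    f"AMD GPU {name} does not support flash-attention"
                ) from e
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            f"Install it with `cd server && make install install-flash-attention-v2-{system}`"
        ) from e


if __name__ == "__main__":
    try:
        engine = select_rocm_engine(["AMD Instinct MI250X"])
        print("engine:", engine, "windowing supported:", False)
    except ImportError as err:
        print("engine selection failed:", err)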
@@ -186,17 +180,12 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
-            None,
-            None,
-            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
-            window_size_left,
-            0,
             False,
             None,
         )
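For readers unfamiliar with the varlen interface, here is a hedged, plain-PyTorch reference for what the remaining arguments describe: packed sequences delimited by cu_seqlens, a maximum sequence length, a softmax_scale, and an optional causal mask. It documents semantics only; it is not the flash_attn_2_cuda API, whose ROCm CK positional signature is exactly what the trimmed call above reflects.

# Reference-only sketch of variable-length causal attention over packed
# sequences, using the (total_tokens, num_heads, head_dim) layout and
# cu_seqlens boundaries of the flash-attn varlen interface.
import math

import torch


def varlen_attention_reference(q, k, v, cu_seqlens, softmax_scale, causal=True):
    out = torch.empty_like(q)
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        # (seq, heads, dim) -> (heads, seq, dim) for one sequence of the batch
        qi = q[start:end].transpose(0, 1)
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        scores = torch.matmul(qi, ki.transpose(-1, -2)) * softmax_scale
        if causal:
            seq_len = end - start
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), 1
            )
            scores = scores.masked_fill(mask, float("-inf"))
        out[start:end] = torch.softmax(scores, dim=-1).matmul(vi).transpose(0, 1)
    return out


# Two packed sequences of lengths 3 and 5, 2 heads, head_dim 4.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
q = torch.randn(8, 2, 4)
k = torch.randn(8, 2, 4)
v = torch.randn(8, 2, 4)
ref = varlen_attention_reference(q, k, v, cu_seqlens, softmax_scale=1 / math.sqrt(4))
print(ref.shape)  # torch.Size([8, 2, 4])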
@@ -234,62 +223,4 @@ elif ENGINE == "triton":
         return output

 else:
-
-    def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-    ):
-        if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
-            )
-
-        # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
-            # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        if v.shape[1] != q.shape[1]:
-            # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-
-        return flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            0,
-            None,
-        )
+    raise RuntimeError(f"Unknown attention engine {ENGINE}")
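The deleted else: branch was the Flash Attention v1 fallback, which had to replicate K/V heads itself because v1 requires q, k and v to carry the same number of heads; with it gone, an unrecognized ENGINE now raises RuntimeError. The standalone sketch below reproduces that unsqueeze/expand/reshape trick for MQA/GQA, assuming the (tokens, heads, head_dim) layout used in the deleted code; expand_kv_heads() is a name invented here for illustration.

# Hedged sketch of the KV-head replication the removed v1 path performed.
import torch


def expand_kv_heads(kv: torch.Tensor, num_query_heads: int) -> torch.Tensor:
    """Replicate KV heads so their count matches the query heads (MQA/GQA)."""
    tokens, kv_heads, head_dim = kv.shape
    if kv_heads == num_query_heads:
        return kv
    if kv_heads == 1:
        # MQA: one shared KV head broadcast across all query heads.
        return kv.expand(-1, num_query_heads, -1)
    # GQA: each KV head serves a contiguous group of query heads.
    return (
        kv.unsqueeze(2)
        .expand(-1, -1, num_query_heads // kv_heads, -1)
        .reshape(tokens, -1, head_dim)
    )


k = torch.randn(8, 2, 64)           # 2 KV heads, 8 tokens, head_dim 64
k_expanded = expand_kv_heads(k, 8)  # expand to 8 query heads
print(k_expanded.shape)             # torch.Size([8, 8, 64])
assert torch.equal(k_expanded[:, 0], k_expanded[:, 1])  # heads in a group are copies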