mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00

add skinny kernel and merge fixes

This commit is contained in:
parent 058162685f
commit 59fd0cbdff
@@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1
 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
 # However, Triton requires a tunning for each prompt length, which is prohibitive.
 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
-ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
-ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
-ENV VLLM_MOE_PADDING=0

 FROM base AS kernel-builder

@@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy

+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

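These defaults are read back as booleans on the Python side. As a hedged illustration only (the env_flag helper is ours, not the repo's), the parsing convention this commit uses for ROCM_USE_SKINNY_GEMM looks like:

import os

# Illustrative only: the boolean ROCm flags baked into the image above are
# parsed with this convention (mirroring the layers/linear.py change further
# down in this commit).
def env_flag(name: str, default: str = "True") -> bool:
    return os.getenv(name, default).lower() in ("true", "1")

ROCM_USE_SKINNY_GEMM = env_flag("ROCM_USE_SKINNY_GEMM")  # image default: 1 (enabled)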
@@ -45,7 +45,6 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -62,7 +62,6 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     out = torch.empty_like(query)
@@ -50,9 +50,8 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: Seqlen,
+    seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -76,6 +75,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape

+    num_kv_heads = key_cache.shape[1]
     gqa_ratio = num_heads // num_kv_heads
     use_custom = (
         custom_attn_available
@@ -92,7 +92,7 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = input_lengths.input_lengths
+    input_lengths = seqlen.input_lengths

     out = torch.empty_like(query)

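With num_kv_heads dropped from the signature, the ROCm paged_attention path now reads the KV head count off the cache tensor itself, which is why every modeling file below stops passing self.num_key_value_heads. A minimal, self-contained sketch of that derivation (the helper and the dummy shapes are ours, assuming the usual [num_blocks, num_kv_heads, head_size // x, block_size, x] key-cache layout):

import torch

def kv_heads_from_cache(key_cache: torch.Tensor) -> int:
    # Dimension 1 of the paged key cache holds the number of KV heads, so the
    # value no longer has to be threaded through every call site.
    return key_cache.shape[1]

# Dummy cache: 2 blocks, 8 KV heads, head_size 64 split as 8 x 8, block_size 16.
dummy_key_cache = torch.zeros(2, 8, 8, 16, 8)
num_kv_heads = kv_heads_from_cache(dummy_key_cache)  # -> 8
gqa_ratio = 32 // num_kv_heads                        # e.g. 32 query heads -> 4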
@@ -220,10 +220,10 @@ if ENGINE == "ck":

     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -237,17 +237,17 @@ if ENGINE == "ck":
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
             None,
             None,
             None,
             None,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
@@ -264,26 +264,27 @@ elif ENGINE == "triton":

     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         out = torch.empty_like(q)

         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             causal,
             softmax_scale,
         )
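The CK and Triton wrappers now take a Seqlen object in place of the old cu_seqlens/max_s pair: seqlen.cu_seqlen_q is passed twice (for the Q and K offsets) and max_q/max_k replace the single max_s. A toy sketch of that mapping (ToySeqlen is a stand-in, not the repo's Seqlen class):

from dataclasses import dataclass
import torch

@dataclass
class ToySeqlen:
    # Stand-in holding only the fields the rewritten wrappers read.
    cu_seqlen_q: torch.Tensor
    max_q: int
    max_k: int

# Two prefill sequences of lengths 3 and 5 -> cumulative boundaries [0, 3, 8].
seqlen = ToySeqlen(
    cu_seqlen_q=torch.tensor([0, 3, 8], dtype=torch.int32), max_q=5, max_k=5
)
# seqlen.cu_seqlen_q goes where cu_seqlens used to go (twice, for Q and K),
# and seqlen.max_q / seqlen.max_k go where the single max_s used to go.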
@@ -1,12 +1,19 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from torch.nn import functional as F
+import os

 if SYSTEM == "rocm":
-    try:
-        from vllm import _custom_C
-    except Exception as e:
-        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
+        "true",
+        "1",
+    )
+
+    if ROCM_USE_SKINNY_GEMM:
+        try:
+            from vllm import _custom_C
+        except Exception as e:
+            raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


 class FastLinear(torch.nn.Module):
@@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
         else:
             self.bias = None

+        self.cu_count = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        self.use_skinny_gemm = (
+            ROCM_USE_SKINNY_GEMM
+            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
+        )
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
@@ -62,9 +77,9 @@ class FastLinearROCm(torch.nn.Module):
             bias = self.bias

         if (
-            SYSTEM == "rocm"
-            and inp.numel() // inp.shape[-1] == 1
+            self.use_skinny_gemm
             and inp.dtype == torch.float16
+            and inp.shape[-1] % 8 == 0
         ):
             batched = False
             inp_shape = inp.shape
@@ -73,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
                 inp = inp.view(-1, inp_shape[-1])
                 batched = True

-            m, k = weight.shape[0], inp_shape[1]
-            out = torch.empty(
-                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
-            )
-            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
-                _custom_C.LLMM1(weight, inp, out, 8)
-            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
+            if m > 8 and n <= 4:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
+                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
+            elif m % 4 == 0 and n == 1 and k <= 8192:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
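Restated outside the class, the new forward picks a GEMM kernel from the problem shape: the wvSpltK skinny kernel when there are very few input rows and a wide output, the LLMM1 kernel for single-row inputs, and torch's dense linear otherwise. A hedged sketch of just that dispatch rule (the function name is ours):

def pick_rocm_gemm_kernel(m: int, n: int, k: int) -> str:
    # m = output features (weight.shape[0]), n = input rows (inp_shape[0]),
    # k = input features (inp_shape[1]), as in the forward pass above.
    if m > 8 and n <= 4:
        return "wvSpltK"  # skinny GEMM: very few rows, wide output
    elif m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"  # single-row custom GEMM
    else:
        return "F.linear"  # fall back to the dense torch path

assert pick_rocm_gemm_kernel(m=4096, n=1, k=4096) == "wvSpltK"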
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -314,7 +315,6 @@ class FlashCohereAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -353,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -15,6 +15,7 @@

 from typing import Any, Dict, List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -363,8 +364,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -380,7 +381,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         # Remove padding.
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -257,7 +257,6 @@ class FlashGemma2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
                 softcap=self.softcap,
             )

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -249,7 +249,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -248,7 +248,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import PAGED_KV
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -237,7 +238,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -236,7 +237,6 @@ class MistralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -275,8 +276,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -293,7 +294,6 @@ class MixtralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PAGED_KV else qkv[:, 1],
+                kv_cache[1] if PAGED_KV else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -189,7 +189,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -211,7 +211,6 @@ class FlashPhiAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -155,7 +155,6 @@ class Qwen2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1,11 +1,11 @@
 from typing import List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -224,7 +224,6 @@ class FlashRWAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -326,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -343,7 +342,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(
@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PAGED_KV else key_value[:, 0],
+                kv_cache[1] if PAGED_KV else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -310,7 +310,6 @@ class FlashMQAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -260,7 +260,6 @@ class Starcoder2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
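Every modeling hunk above applies the same gate: with paged KV enabled the prefill attention kernel reads from the cache tensors, otherwise it receives the freshly computed key/value. A condensed, self-contained restatement of that pattern (the helper name is ours, not the repo's):

import torch
from typing import Tuple

def select_prefill_kv(
    paged_kv: bool,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # With paged KV (CUDA builds) the kernel reads K/V from the paged cache;
    # on ROCm/IPEX (PAGED_KV is False) it gets the current key/value tensors.
    if paged_kv:
        return kv_cache[0], kv_cache[1]
    return key, value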
@@ -1379,6 +1379,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
+        max_s = seqlen
         seqlen = Seqlen(
             input_lengths=input_lengths,
             prefix_lengths=prefix_lens_tensor,
@@ -1396,7 +1397,7 @@ class FlashCausalLM(Model):
             block_tables=None,
             seqlen=seqlen,
             slots=slots,
-            max_s=seqlen,
+            max_s=max_s,
             lm_head_indices=None,
             prefill_cache_indices=None,
         )
@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM

 PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

+PAGED_KV: bool
+if SYSTEM in {"rocm", "ipex"}:
+    PAGED_KV = False
+else:
+    PAGED_KV = True
+

 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX