Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Moving to SYSTEM enum.

commit 34b9289c3d
parent c84718c8b6
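This commit replaces the per-vendor booleans IS_CUDA_SYSTEM, IS_ROCM_SYSTEM and IS_XPU_SYSTEM from text_generation_server.utils.import_utils with a single SYSTEM string ("cuda", "rocm" or "xpu"), so call sites branch on one value instead of three flags. A minimal sketch of the pattern follows (illustrative only, not the exact TGI code; the real detection lives in import_utils.py, shown near the end of the diff):

# Illustrative sketch of the SYSTEM "enum" pattern; names mirror the diff below,
# but this is not the exact TGI implementation.
import torch


def detect_system():
    # Check ROCm (hip) before CUDA, mirroring the order used in import_utils.py below.
    if torch.version.hip is not None:
        return "rocm"
    if torch.version.cuda is not None:
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return None


SYSTEM = detect_system()

# Call sites change from `if IS_CUDA_SYSTEM:` to a string comparison:
if SYSTEM == "cuda":
    pass  # import CUDA-only fused kernels here
elif SYSTEM == "rocm":
    pass  # import the ROCm/vLLM ports instead
elif SYSTEM == "xpu":
    pass  # import intel_extension_for_pytorch kernels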
@@ -1,7 +1,7 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import (
-    IS_ROCM_SYSTEM,
+    SYSTEM,
 )

 try:
@@ -10,7 +10,7 @@ except Exception:
     major = 1

 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
@@ -2,9 +2,7 @@ import torch
 from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
+    SYSTEM,
 )


@@ -35,19 +33,60 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
 torch.nn.LayerNorm.load = load_layer_norm
 torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
-    from vllm import layernorm_ops
-elif IS_XPU_SYSTEM:
-    import intel_extension_for_pytorch as ipex
-else:
-    dropout_layer_norm = None
-

-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if IS_XPU_SYSTEM:
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
+                residual = hidden_states
+
+                return super(FastLayerNorm, self).forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+elif SYSTEM == "rocm":
+    from vllm import layernorm_ops
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+
+elif SYSTEM == "xpu":
+    import intel_extension_for_pytorch as ipex
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
             res_out = hidden_states
             out = ipex.llm.functional.add_layer_norm(
                 residual, hidden_states, self.weight, self.bias, self.eps, True
@@ -55,7 +94,19 @@ class FastLayerNorm(nn.LayerNorm):
             if residual is not None:
                 res_out = residual
             return out, res_out
-        elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+
+
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if SYSTEM == "xpu":
+            res_out = hidden_states
+            out = ipex.llm.functional.add_layer_norm(
+                residual, hidden_states, self.weight, self.bias, self.eps, True
+            )
+            if residual is not None:
+                res_out = residual
+            return out, res_out
+        elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
@@ -102,7 +153,7 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        if IS_XPU_SYSTEM:
+        if SYSTEM == "xpu":
             residual_out = hidden_states
             out = ipex.llm.functional.add_rms_norm(
                 residual,
@@ -131,7 +182,7 @@ class FastRMSNorm(nn.Module):
             hidden_states = hidden_states.to(self.weight.dtype)

             return self.weight * hidden_states, residual
-        elif IS_CUDA_SYSTEM:
+        elif SYSTEM == "cuda":
             # faster post attention rms norm
             (
                 normed_hidden_states,
@@ -158,7 +209,7 @@ class FastRMSNorm(nn.Module):
                 res = hidden_states

             return normed_hidden_states, res
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
             if residual is not None:
                 hidden_states += residual
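All of the FastLayerNorm variants in the hunks above share the same forward contract: forward(hidden_states, residual) returns (normed_hidden_states, new_residual), with the residual add performed before (or fused with) the normalization. A plain-PyTorch stand-in of that contract, not the fused CUDA/ROCm/XPU kernels:

# Hedged sketch of the FastLayerNorm forward contract used throughout this diff.
import torch
from torch import nn


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        # Add the incoming residual, remember the pre-norm activations,
        # then normalize; real backends fuse these steps in one kernel.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        return super().forward(hidden_states), residual


# Usage: the caller threads `residual` through consecutive layers.
ln = FastLayerNorm(64)
x = torch.randn(2, 8, 64)
normed, res = ln(x)            # first layer: residual starts as the input
normed, res = ln(normed, res)  # later layers: residual add + norm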
@@ -1,4 +1,5 @@
 import torch
+from text_generation_server.utils.import_utils import SYSTEM


 class FastLinear(torch.nn.Module):
@@ -126,7 +127,7 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        if IS_ROCM_SYSTEM:
+        if SYSTEM == "rocm":
             raise NotImplementedError(
                 "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                 "to use Exllama/GPTQ kernels for AWQ inference."
@@ -1,15 +1,12 @@
 import torch
 from torch import nn

-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
     from vllm import pos_encoding_ops


@@ -50,7 +47,7 @@ class PositionRotaryEmbedding(nn.Module):
         sin: torch.Tensor,
     ):
         # Such controlflows may add some overhead.
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             rotary_dim = cos.shape[-1]
             q1 = query[..., :rotary_dim]
             q2 = query[..., rotary_dim : 2 * rotary_dim]
@@ -61,7 +58,7 @@ class PositionRotaryEmbedding(nn.Module):
             k2 = key[..., rotary_dim : 2 * rotary_dim]

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

@@ -69,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module):

             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
@@ -223,7 +220,7 @@ class PositionRotaryEmbedding(nn.Module):
         """
         Return cos and sin for the asked position ids
         """
-        if IS_ROCM_SYSTEM:
+        if SYSTEM == "rocm":
            # For RoCm, we always use float cos/sin to avoid a cast.
            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
@@ -2,7 +2,7 @@ import math
 import torch

 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -25,7 +25,7 @@ class CacheManager:
         self.repeat_slots = repeat_slots

         element_size = torch.tensor([], dtype=dtype).element_size()
-        if IS_XPU_SYSTEM:
+        if SYSTEM == "xpu":
             x = 1
         else:
             x = self.block_size // element_size
@@ -26,18 +26,22 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
 else:
     dropout_layer_norm = None
@@ -52,7 +56,7 @@ class CohereRotary(PositionRotaryEmbedding):
         sin: torch.Tensor,
     ):
         # Such controlflows may add some overhead.
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             import rotary_emb

             q1 = query[..., ::2]
@@ -64,7 +68,7 @@ class CohereRotary(PositionRotaryEmbedding):
             k2 = key[..., 1::2]

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import pos_encoding_ops

             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
@@ -90,7 +94,7 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps

     def forward(self, hidden_states):
-        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+        if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
             hidden_states = hidden_states.reshape(
                 -1, self.weight.shape[0], self.weight.shape[1]
             )
@@ -21,21 +21,26 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if not IS_XPU_SYSTEM:
+if SYSTEM != "xpu":
     from vllm.model_executor.layers.fused_moe import fused_moe

 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     FastLinear,
-    FastLayerNorm,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
 from text_generation_server.utils.log import log_once

+
@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np

 from torch import nn
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if not IS_XPU_SYSTEM:
+if SYSTEM != "xpu":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -36,14 +36,18 @@ from loguru import logger
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     FastLinear,
-    FastRMSNorm,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)


 class MixtralConfig(PretrainedConfig):
@@ -55,12 +55,14 @@ from text_generation_server.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
     from vllm import layernorm_ops
+else:
+    raise RuntimeError(f"Unsupported system {SYSTEM}")


 @dataclass
@@ -373,7 +375,7 @@ class IdeficsRMSNorm(nn.Module):
             hidden_states = hidden_states.to(self.weight.dtype)

             return self.weight * hidden_states
-        elif IS_CUDA_SYSTEM:
+        elif SYSTEM == "cuda":
             # faster post attention rms norm
             unwrap = False
             if len(hidden_states.shape) > 2:
@@ -405,7 +407,7 @@ class IdeficsRMSNorm(nn.Module):
                 normed_hidden_states = normed_hidden_states.view(*shape)

             return normed_hidden_states
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
             if residual is not None:
                 hidden_states += residual
@@ -12,10 +12,6 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
-
-from text_generation_server.models import Model
-from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -32,13 +28,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION

-tracer = trace.get_tracer(__name__)
 from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
+    empty_cache,
+    synchronize,
+    get_free_memory,
 )

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -757,10 +754,8 @@ class FlashCausalLM(Model):

     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()
+
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -780,10 +775,7 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.synchronize(self.device)
-        elif IS_XPU_SYSTEM:
-            torch.xpu.synchronize(self.device)
+        synchronize(self.device)

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -791,20 +783,7 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(
-                self.device
-            ).total_memory
-
-            free_memory = max(
-                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-            )
-        elif IS_XPU_SYSTEM:
-            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory * 0.5)
-        else:
-            raise NotImplementedError("FlashModel is only available on GPU")
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

         num_blocks = (
             # Leave 5% for some wiggle room
@@ -18,7 +18,7 @@ from text_generation_server.utils import (

 tracer = trace.get_tracer(__name__)

-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM


 class FlashLlama(FlashCausalLM):
@@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
@@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

@@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )

-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
@@ -2,13 +2,8 @@ import os
 import torch

 from loguru import logger
-import math

-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -16,69 +11,9 @@ HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex

-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-    if not torch.cuda.is_available():
-        raise ImportError("CUDA is not available")
-
-    major, minor = torch.cuda.get_device_capability()
-    is_sm75 = major == 7 and minor == 5
-    is_sm8x = major == 8 and minor >= 0
-    is_sm90 = major == 9 and minor == 0
-
-    HAS_FLASH_ATTN = False
-    HAS_FLASH_ATTN_V2_CUDA = False
-    HAS_FLASH_ATTN_V2_ROCM = False
-    try:
-        try:
-            import flash_attn_2_cuda
-        except ImportError:
-            architecture_suffix = ""
-            if IS_CUDA_SYSTEM:
-                architecture_suffix = "-cuda"
-            elif IS_ROCM_SYSTEM:
-                architecture_suffix = "-rocm"
-            raise ImportError(
-                "Flash Attention V2 is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-            )
-        if not (is_sm8x or is_sm90):
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported for "
-                "Flash Attention V2"
-            )
-        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-    except ImportError as e:
-        try:
-            import flash_attn_cuda
-        except ImportError:
-            raise ImportError(
-                "Flash Attention is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                "or install flash attention with `cd server && make install install-flash-attention`"
-            ) from e
-
-        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            ) from e
-        elif IS_ROCM_SYSTEM:
-            for idx in range(torch.cuda.device_count()):
-                if "MI210" not in torch.cuda.get_device_name(
-                    idx
-                ) and "MI250" not in torch.cuda.get_device_name(idx):
-                    raise ImportError(
-                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                    )
-
-        logger.warning(f"Unable to use Flash Attention V2: {e}")
-        HAS_FLASH_ATTN = True
-
-
 def attention(
     q,
     k,
@@ -92,7 +27,6 @@ def attention(
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")

-    if IS_XPU_SYSTEM:
         if window_size_left != -1:
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
@@ -114,7 +48,77 @@ def attention(
             None,
         )

+
+if SYSTEM in {"cuda", "rocm"}:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
+    HAS_FLASH_ATTN = False
+    HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
+    try:
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = f"-{SYSTEM}"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
+        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif SYSTEM == "rocm":
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
+
+
 if HAS_FLASH_ATTN_V2_CUDA:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -136,7 +140,21 @@ def attention(
             False,
             None,
         )
+
 elif HAS_FLASH_ATTN_V2_ROCM:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
                 f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
@@ -159,7 +177,19 @@ def attention(
             False,
             None,
         )
+
 elif HAS_FLASH_ATTN:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
@@ -209,4 +239,5 @@ def attention(
             None,
         )

+else:
     raise NotImplementedError("flash attention is not installed")
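The flash_attn hunks above gate Flash Attention V2 on the GPU's compute capability. A short sketch of that check, assuming a CUDA/ROCm torch build with at least one visible device:

# Minimal sketch of the compute-capability gating used in the hunks above.
import torch

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5      # Turing
is_sm8x = major == 8 and minor >= 0      # Ampere / Ada
is_sm90 = major == 9 and minor == 0      # Hopper

supports_flash_attn_v2 = is_sm8x or is_sm90           # V2 requires sm80+ or sm90
supports_flash_attn_v1 = is_sm75 or supports_flash_attn_v2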
@@ -10,6 +10,32 @@ def is_xpu_available():
     return hasattr(torch, "xpu") and torch.xpu.is_available()


-IS_ROCM_SYSTEM = torch.version.hip is not None
-IS_CUDA_SYSTEM = torch.version.cuda is not None
-IS_XPU_SYSTEM = is_xpu_available()
+def get_cuda_free_memory(device, memory_fraction):
+    total_free_memory, _ = torch.cuda.mem_get_info(device)
+    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
+    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
+    return free_memory
+
+
+def get_xpu_free_memory(device):
+    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
+    free_memory = int(total_gpu_memory * 0.5)
+    return free_memory
+
+
+SYSTEM = None
+if torch.version.hip is not None:
+    SYSTEM = "rocm"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif torch.version.cuda is not None:
+    SYSTEM = "cuda"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif is_xpu_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
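With SYSTEM, empty_cache, synchronize and get_free_memory picked once at import time, callers such as flash_causal_lm.py (earlier in this diff) no longer branch per vendor. A hedged usage sketch; the import path and helper names come from the hunk above, while the device index and memory fraction are illustrative:

# Hypothetical caller of the device-agnostic helpers defined above.
import torch

from text_generation_server.utils.import_utils import (
    SYSTEM,
    empty_cache,
    synchronize,
    get_free_memory,
)

device = torch.device("xpu:0" if SYSTEM == "xpu" else "cuda:0")
empty_cache()                               # torch.cuda.empty_cache or torch.xpu.empty_cache
synchronize(device)                         # vendor-specific synchronize chosen at import time
free_bytes = get_free_memory(device, 0.9)   # fraction argument as used by the CUDA/ROCm helper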
@@ -1,13 +1,9 @@
 import torch
-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 _PARTITION_SIZE = 512

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex


@@ -18,17 +14,17 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
+    if SYSTEM == "cuda":
         from vllm._C import cache_ops

         cache_ops.reshape_and_cache(
             key, value, key_cache, value_cache, slots, "auto", 1.0
         )
-    elif IS_ROCM_SYSTEM:
+    elif SYSTEM == "rocm":
         from vllm import cache_ops

         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    elif IS_XPU_SYSTEM:
+    elif SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
@@ -68,7 +64,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    if IS_XPU_SYSTEM:
+    if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
@@ -91,7 +87,7 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             from vllm._C import ops

             ops.paged_attention_v1(
@@ -109,7 +105,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import attention_ops

             attention_ops.paged_attention_v1(
@@ -143,7 +139,7 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)

-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             from vllm._C import ops

             ops.paged_attention_v2(
@@ -164,7 +160,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
            from vllm import attention_ops

            attention_ops.paged_attention_v2(