align to ipex llm ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A <yi.a.wang@intel.com> authored 2024-03-05 18:29:31 -08:00, committed by Nicolas Patry
parent 515a0edebe
commit 23a1cb0511
3 changed files with 46 additions and 19 deletions

File 1 of 3 (flash attention utilities)

@@ -4,7 +4,11 @@ import torch
 from loguru import logger
 import math
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -12,6 +16,9 @@ HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -90,7 +97,7 @@ def attention(
         raise ValueError(
             f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
         )
-        return torch.xpu.varlen_fwd(
+        return ipex.llm.modules.VarlenAttention.apply(
             q,
             k,
             v,
@@ -104,10 +111,9 @@ def attention(
             False,
             True,
             False,
-            None
+            None,
         )
 
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
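For context, the op being swapped here is a variable-length ("varlen") attention kernel over a packed batch, the same contract that flash_attn_2_cuda.varlen_fwd fulfils on CUDA. The full IPEX signature is not visible in this hunk (the middle arguments are elided), so the snippet below is only an unfused reference of what such a kernel computes; cu_seqlens, softmax_scale, and causal are assumed names based on the surrounding attention() function, not taken from the hidden part of the diff.

import torch


def varlen_attention_reference(q, k, v, cu_seqlens, softmax_scale, causal=True):
    # q, k, v: [total_tokens, num_heads, head_dim], sequences packed back to back.
    # cu_seqlens: [batch + 1] cumulative token offsets, e.g. [0, 7, 12, ...].
    # Unfused reference only; fused kernels never materialize the score matrix.
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        qi = q[start:end].transpose(0, 1)  # [heads, seq, dim]
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        scores = torch.matmul(qi, ki.transpose(-1, -2)) * softmax_scale
        if causal:
            seq = end - start
            mask = torch.triu(
                torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1
            )
            scores = scores.masked_fill(mask, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[start:end] = torch.matmul(probs, vi).transpose(0, 1)
    return out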

File 2 of 3 (layers: layer norm, RMS norm, rotary embeddings)

@@ -18,7 +18,14 @@ except ImportError:
 from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 HAS_AWQ = True
 try:
@@ -816,7 +823,13 @@ try:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
-                out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+                out = ipex.llm.modules.FastLayerNorm.apply(
+                    hidden_states,
+                    self.normalized_shape,
+                    self.eps,
+                    self.weight,
+                    self.bias,
+                )
                 return out, residual
             elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                 if residual is not None:
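The control flow around the changed call is: add the residual into the hidden states, keep that pre-norm sum as the next residual, then layer-normalize; only the normalization itself is handed to the fused IPEX op. An unfused plain-PyTorch sketch of the same step, for reference (the FastLayerNorm.apply argument order above is taken from the diff and may differ between IPEX releases):

import torch
import torch.nn.functional as F


def layer_norm_with_residual(hidden_states, residual, weight, bias, eps):
    # Mirrors the surrounding control flow: the residual is added first and the
    # pre-norm sum is returned so the next layer can reuse it as its residual.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    out = F.layer_norm(
        hidden_states, (hidden_states.size(-1),), weight=weight, bias=bias, eps=eps
    )
    return out, residual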
@@ -868,8 +881,11 @@ try:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
-                out = torch.ops.torch_ipex.rms_norm(
-                    hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
+                out = ipex.llm.modules.RMSNorm.apply(
+                    hidden_states,
+                    [hidden_states.size(-1)],
+                    self.weight,
+                    self.variance_epsilon,
                 )
                 return out[0], residual
             elif hidden_states.shape[-1] > 8192:
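For reference, this is what an RMS-norm kernel computes; the return out[0] in the diff indicates the fused op returns a tuple whose first element is the normalized tensor. An unfused sketch, computing the statistics in float32 for stability as most Hugging Face implementations do:

import torch


def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # Root-mean-square normalization over the last dimension, then a learned scale.
    input_dtype = hidden_states.dtype
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + variance_epsilon)
    return weight * hs.to(input_dtype)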
@@ -999,15 +1015,14 @@ try:
                 # Inplace operation, updating query and key.
                 pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
             elif IS_XPU_SYSTEM:
-                sin = sin.expand(query.shape)
-                cos = cos.expand(query.shape)
-                torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
+                ipex.llm.modules.RotaryEmbedding.apply(
+                    query, key, sin, cos, query.size(-1), True
+                )
             else:
                 raise ValueError(
                     "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                 )
 
     @classmethod
     def static(cls, config, dim, base, device):
         inv_freq = _create_inv_freq(dim, base, device)
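Both the removed torch_ipex op and the new RotaryEmbedding call apply rotary position embeddings to query and key in place. Out of place and unfused, the "half" (GPT-NeoX style) rotation they correspond to looks roughly like the sketch below; the shapes (query/key as [num_tokens, num_heads, head_size], cos/sin as [num_tokens, 1, head_size // 2]) are assumptions for illustration, not taken from the hunk.

import torch


def rotate_half_reference(x, cos, sin):
    # x: [num_tokens, num_heads, head_size]; cos/sin broadcast over the head axis.
    # Each half of the head dimension is rotated as an (x1, x2) pair.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def apply_rotary_reference(query, key, cos, sin):
    # The fused op mutates query/key in place; this sketch returns new tensors.
    return rotate_half_reference(query, cos, sin), rotate_half_reference(key, cos, sin)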
@@ -1123,8 +1138,6 @@ try:
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
-            if IS_XPU_SYSTEM:
-                return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
 
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)

File 3 of 3 (paged attention / KV cache utilities)

@@ -1,8 +1,14 @@
 import torch
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 _PARTITION_SIZE = 512
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 
 def reshape_and_cache(
@@ -23,7 +29,9 @@ def reshape_and_cache(
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     elif IS_XPU_SYSTEM:
-        torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
     else:
         raise ValueError("vllm is not supported on your system")
@@ -67,18 +75,18 @@ def attention(
     # to parallelize.
     if IS_XPU_SYSTEM:
         query = query.contiguous()
-        return torch.xpu.IpexPaged_attention(
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
             key_cache,
             value_cache,
             kv_head_mapping,
+            softmax_scale,
             block_tables,
             input_lengths,
-            softmax_scale,
             block_size,
             max_s,
-            None
+            None,
         )
 
     if use_v1:
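The decode-time ("single query") paged attention call takes one query token per sequence and gathers that sequence's keys/values from the cache blocks listed in its block table; the reordering in this hunk only moves softmax_scale to sit after kv_head_mapping in the new entry point. Below is an unfused reference of the computation, using a simplified [num_blocks, block_size, num_kv_heads, head_size] cache layout and kv_head_mapping to broadcast KV heads for grouped-query attention; the layout and shapes are assumptions for illustration, not the kernel's actual memory format.

import torch


def paged_attention_reference(
    query, key_cache, value_cache, kv_head_mapping,
    softmax_scale, block_tables, input_lengths, block_size,
):
    # query: [num_seqs, num_query_heads, head_size] (one decode token per sequence)
    # key_cache, value_cache: [num_blocks, block_size, num_kv_heads, head_size] (simplified)
    # kv_head_mapping: [num_query_heads] -> index of the KV head each query head reads.
    # block_tables: [num_seqs, max_blocks_per_seq]; input_lengths: [num_seqs]
    out = torch.empty_like(query)
    for i in range(query.size(0)):
        length = int(input_lengths[i])
        num_blocks = (length + block_size - 1) // block_size
        blocks = block_tables[i, :num_blocks]
        # Gather this sequence's cached keys/values and trim the partial last block.
        k = key_cache[blocks].reshape(-1, key_cache.size(-2), key_cache.size(-1))[:length]
        v = value_cache[blocks].reshape(-1, value_cache.size(-2), value_cache.size(-1))[:length]
        k = k[:, kv_head_mapping, :]  # [length, num_query_heads, head_size]
        v = v[:, kv_head_mapping, :]
        scores = torch.einsum("hd,lhd->hl", query[i], k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        out[i] = torch.einsum("hl,lhd->hd", probs, v)
    return out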