From 23a1cb0511005495b27f6fe95455c74a645ee008 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 5 Mar 2024 18:29:31 -0800
Subject: [PATCH] align to ipex llm ops

Signed-off-by: Wang, Yi A
---
 .../utils/flash_attn.py                       | 14 +++++---
 server/text_generation_server/utils/layers.py | 33 +++++++++++++------
 .../utils/paged_attention.py                  | 18 +++++++---
 3 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index a8c8f75e..e49447b9 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -4,7 +4,11 @@ import torch
 from loguru import logger
 import math
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -12,6 +16,9 @@
 HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -90,7 +97,7 @@ def attention(
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
-        return torch.xpu.varlen_fwd(
+        return ipex.llm.modules.VarlenAttention.apply(
             q,
             k,
             v,
@@ -104,10 +111,9 @@ def attention(
             False,
             True,
             False,
-            None
+            None,
         )
 
-
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c66a8d2c..95d777a6 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,7 +18,14 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 HAS_AWQ = True
 try:
@@ -816,7 +823,13 @@ try:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
-                out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+                out = ipex.llm.modules.FastLayerNorm.apply(
+                    hidden_states,
+                    self.normalized_shape,
+                    self.eps,
+                    self.weight,
+                    self.bias,
+                )
                 return out, residual
             elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                 if residual is not None:
@@ -868,8 +881,11 @@ try:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
-                out = torch.ops.torch_ipex.rms_norm(
-                    hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
+                out = ipex.llm.modules.RMSNorm.apply(
+                    hidden_states,
+                    [hidden_states.size(-1)],
+                    self.weight,
+                    self.variance_epsilon,
                 )
                 return out[0], residual
             elif hidden_states.shape[-1] > 8192:
@@ -999,15 +1015,14 @@ try:
 
                 # Inplace operation, updating query and key.
                 pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
             elif IS_XPU_SYSTEM:
-                sin = sin.expand(query.shape)
-                cos = cos.expand(query.shape)
-                torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
+                ipex.llm.modules.RotaryEmbedding.apply(
+                    query, key, sin, cos, query.size(-1), True
+                )
             else:
                 raise ValueError(
                     "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                 )
-
         @classmethod
         def static(cls, config, dim, base, device):
             inv_freq = _create_inv_freq(dim, base, device)
@@ -1123,8 +1138,6 @@ try:
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
 
-            if IS_XPU_SYSTEM:
-                return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
 
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 0130cd0c..cff718c7 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -1,8 +1,14 @@
 import torch
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 _PARTITION_SIZE = 512
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 
 def reshape_and_cache(
@@ -23,7 +29,9 @@
 
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     elif IS_XPU_SYSTEM:
-        torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
     else:
         raise ValueError("vllm is not supported on your system")
 
@@ -67,18 +75,18 @@
     # to parallelize.
    if IS_XPU_SYSTEM:
        query = query.contiguous()
-        return torch.xpu.IpexPaged_attention(
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
             key_cache,
             value_cache,
             kv_head_mapping,
+            softmax_scale,
             block_tables,
             input_lengths,
-            softmax_scale,
             block_size,
             max_s,
-            None
+            None,
         )
 
     if use_v1:
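
Note: the layers.py hunks above replace direct torch.ops.torch_ipex / torch.xpu kernel calls with the ipex.llm.modules wrappers (FastLayerNorm, RMSNorm, RotaryEmbedding), and drop the XPU-only cos/sin repeat in get_cos_sin, which suggests the new rotary op consumes the same half-width cos/sin tables as the CUDA and ROCm paths. The sketch below is not part of the patch; it is a minimal illustration of the resulting fused residual-add + RMSNorm dispatch. The rms_norm_step helper name is illustrative, and the RMSNorm.apply signature is copied from the hunks rather than from IPEX documentation, so it is an assumption about the intel_extension_for_pytorch release this patch targets.

# Sketch only: fused residual-add + RMSNorm dispatch mirroring the patched
# FastRMSNorm.forward. rms_norm_step is a hypothetical helper; the
# ipex.llm.modules.RMSNorm.apply signature is assumed from the hunks above.
import torch

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex


def rms_norm_step(hidden_states, weight, eps, residual=None):
    # Fold the residual add into the same step, as the patched forward does.
    if residual is not None:
        hidden_states += residual
    residual = hidden_states
    if IS_XPU_SYSTEM:
        # Fused IPEX kernel; returns the normalized output as out[0],
        # matching how the patch consumes it.
        out = ipex.llm.modules.RMSNorm.apply(
            hidden_states,
            [hidden_states.size(-1)],
            weight,
            eps,
        )
        return out[0], residual
    # Plain-PyTorch reference path for non-XPU systems (unfused).
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(weight.dtype), residual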
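
Note: for the paged_attention.py changes, the sketch below shows how the two ipex.llm.modules.PagedAttention entry points fit together during a single decode step. It is not part of the patch: xpu_decode_attention and the tensor names are illustrative, and the argument order for single_query_cached_kv_attention (including the trailing None, presumably the alibi-slopes slot) is taken verbatim from the hunk above rather than from IPEX documentation.

# Illustrative only: one XPU decode step using the PagedAttention wrappers
# exactly as the patched paged_attention.py calls them. Argument order and
# the trailing None are copied from the hunks; treat them as assumptions
# about the targeted intel_extension_for_pytorch release.
import torch
import intel_extension_for_pytorch as ipex


def xpu_decode_attention(
    query,            # [num_seqs, num_heads, head_size] for the new tokens
    key,              # new key states for this step
    value,            # new value states for this step
    key_cache,        # paged key cache (TGI/vLLM block layout)
    value_cache,      # paged value cache
    slots,            # flat cache slot index per new token
    kv_head_mapping,  # maps query heads to KV heads (GQA/MQA)
    softmax_scale,
    block_tables,     # per-sequence block indices into the cache
    input_lengths,    # current length of each sequence
    block_size,
    max_s,
):
    # 1) Write this step's K/V into the paged cache.
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )
    # 2) Attend over the cached blocks; the kernel fills `out` in place.
    out = torch.empty_like(query)
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query.contiguous(),
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        block_size,
        max_s,
        None,  # trailing argument kept as in the patch (likely alibi slopes)
    )
    return out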