mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
align to ipex llm ops
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
515a0edebe
commit
23a1cb0511
@ -4,7 +4,11 @@ import torch
|
||||
from loguru import logger
|
||||
import math
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
@ -12,6 +16,9 @@ HAS_FLASH_ATTN = True
|
||||
HAS_FLASH_ATTN_V2_CUDA = False
|
||||
HAS_FLASH_ATTN_V2_ROCM = False
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
if not torch.cuda.is_available():
|
||||
raise ImportError("CUDA is not available")
|
||||
@ -90,7 +97,7 @@ def attention(
|
||||
raise ValueError(
|
||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
return torch.xpu.varlen_fwd(
|
||||
return ipex.llm.modules.VarlenAttention.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -104,10 +111,9 @@ def attention(
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
None
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
if HAS_FLASH_ATTN_V2_CUDA:
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
|
@ -18,7 +18,14 @@ except ImportError:
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
HAS_AWQ = True
|
||||
try:
|
||||
@ -816,7 +823,13 @@ try:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
out = ipex.llm.modules.FastLayerNorm.apply(
|
||||
hidden_states,
|
||||
self.normalized_shape,
|
||||
self.eps,
|
||||
self.weight,
|
||||
self.bias,
|
||||
)
|
||||
return out, residual
|
||||
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
if residual is not None:
|
||||
@ -868,8 +881,11 @@ try:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
out = torch.ops.torch_ipex.rms_norm(
|
||||
hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
|
||||
out = ipex.llm.modules.RMSNorm.apply(
|
||||
hidden_states,
|
||||
[hidden_states.size(-1)],
|
||||
self.weight,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out[0], residual
|
||||
elif hidden_states.shape[-1] > 8192:
|
||||
@ -999,15 +1015,14 @@ try:
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
elif IS_XPU_SYSTEM:
|
||||
sin = sin.expand(query.shape)
|
||||
cos = cos.expand(query.shape)
|
||||
torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
|
||||
ipex.llm.modules.RotaryEmbedding.apply(
|
||||
query, key, sin, cos, query.size(-1), True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
@ -1123,8 +1138,6 @@ try:
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
|
||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
|
@ -1,8 +1,14 @@
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
@ -23,7 +29,9 @@ def reshape_and_cache(
|
||||
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
elif IS_XPU_SYSTEM:
|
||||
torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots
|
||||
)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
||||
@ -67,18 +75,18 @@ def attention(
|
||||
# to parallelize.
|
||||
if IS_XPU_SYSTEM:
|
||||
query = query.contiguous()
|
||||
return torch.xpu.IpexPaged_attention(
|
||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
softmax_scale,
|
||||
block_size,
|
||||
max_s,
|
||||
None
|
||||
None,
|
||||
)
|
||||
|
||||
if use_v1:
|
||||
|
Loading…
Reference in New Issue
Block a user