mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

reenable xpu for tgi

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

This commit is contained in:
parent efb73fcb59
commit 3d35292907
@@ -43,6 +43,7 @@ USER root
 RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
     dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
 
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
     | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
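Note: this first hunk (presumably the Intel variant of the TGI Dockerfile) registers Intel's graphics-driver GPG key next to the oneAPI key the file already installs, so apt can verify packages from https://repositories.intel.com/gpu as well. The matching apt source entry is not shown in this diff; it would take roughly the same "deb [signed-by=/usr/share/keyrings/intel-graphics.gpg] ..." form as the oneAPI line above, with suite and component depending on the base image.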
@@ -9,6 +9,8 @@ if SYSTEM == "cuda":
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
+elif SYSTEM == "xpu":
+    import intel_extension_for_pytorch as ipex
 
 
 def _create_inv_freq(dim, base, device):
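This hunk extends the rotary-embedding module's backend dispatch so that XPU uses intel_extension_for_pytorch (ipex) where CUDA uses the rotary_emb extension and ROCm uses vLLM's ops. For orientation, the _create_inv_freq helper visible in the context is the standard RoPE inverse-frequency computation; a minimal sketch of it (its body is not part of this diff, so treat it as an assumption):

    import torch

    def _create_inv_freq(dim, base, device):
        # One inverse frequency per pair of hidden dimensions: base^(-2i/dim)
        return 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )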
@@ -62,7 +62,7 @@ if SYSTEM == "cuda":
 elif SYSTEM == "rocm":
     from vllm._C import ops
 else:
-    raise RuntimeError(f"Unsupported system {SYSTEM}")
+    dropout_layer_norm = None
 
 
 @dataclass
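Here the layer-norm module stops raising on systems other than CUDA and ROCm: it now leaves dropout_layer_norm as None, which lets the module import cleanly on XPU and lets call sites fall back to stock PyTorch. A hypothetical sketch of that call-site pattern (the names are illustrative, not from this diff):

    import torch

    def layer_norm_maybe_fused(hidden_states, weight, bias, eps, fused_op=None):
        # fused_op stands in for the optional dropout_layer_norm kernel;
        # when it is None (as on xpu), use the stock implementation.
        if fused_op is None:
            return torch.nn.functional.layer_norm(
                hidden_states, (hidden_states.shape[-1],), weight, bias, eps
            )
        return fused_op(hidden_states, weight, bias, eps)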
@@ -5,7 +5,9 @@ from loguru import logger
 import math
 
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.flash_attn_triton import triton_attention
+
+if SYSTEM != "xpu":
+    from text_generation_server.utils.flash_attn_triton import triton_attention
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
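The Triton flash-attention kernel only builds against CUDA/ROCm toolchains, so its import is now gated on the detected platform. The idiom in isolation (a sketch; the else arm is my addition to show how callers can test availability):

    SYSTEM = "xpu"  # illustration only; TGI detects the platform at import time

    if SYSTEM != "xpu":
        from text_generation_server.utils.flash_attn_triton import triton_attention
    else:
        triton_attention = None  # callers must check this before dispatching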
@@ -15,43 +17,6 @@ HAS_FLASH_ATTN_V2_ROCM = False
 ROCM_USE_FLASH_ATTN_V2_CK = False
 ROCM_USE_FLASH_ATTN_V2_TRITON = False
 
-if SYSTEM == "xpu":
-    import intel_extension_for_pytorch as ipex
-
-    def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-    ):
-        if window_size_left <= 0 and window_size_left != -1:
-            raise ValueError("`window_size_left` must be > 0 or -1")
-
-        if window_size_left != -1:
-            raise ValueError(
-                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-        return ipex.llm.functional.varlen_attention(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-
 
 if SYSTEM in {"cuda", "rocm"}:
     if not torch.cuda.is_available():
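Note that the XPU attention wrapper removed above is not being dropped: the next hunk reinserts the identical block later in the module, after the CUDA/ROCm capability probing, so the existing "if HAS_FLASH_ATTN_V2_CUDA:" branch can become an "elif" of the platform check.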
@@ -124,8 +89,44 @@ if SYSTEM in {"cuda", "rocm"}:
         logger.warning(f"Unable to use Flash Attention V2: {e}")
         HAS_FLASH_ATTN = True
 
 
+if SYSTEM == "xpu":
+    import intel_extension_for_pytorch as ipex
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
+
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
-if HAS_FLASH_ATTN_V2_CUDA:
+elif HAS_FLASH_ATTN_V2_CUDA:
 
     def attention(
         q,
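Since ipex.llm.functional.varlen_attention takes a long positional argument list, an annotated reading of the call in the new branch may help. The parameter names in the comments are my interpretation of the IPEX API, not something this diff states, so verify them against the installed intel_extension_for_pytorch:

    import intel_extension_for_pytorch as ipex

    def xpu_attention(q, k, v, out, cu_seqlens, max_s, softmax_scale):
        # Same call as the new branch above, with each positional argument named.
        return ipex.llm.functional.varlen_attention(
            q, k, v, out,            # packed varlen query/key/value and output buffer
            cu_seqlens, cu_seqlens,  # cumulative sequence lengths for q and for k/v
            max_s, max_s,            # maximum sequence lengths for q and for k/v
            0.0,                     # dropout probability (disabled for inference)
            softmax_scale,           # attention scaling, typically 1/sqrt(head_dim)
            False,                   # zero_tensors (assumed)
            True,                    # is_causal: decoder-style causal masking (assumed)
            False,                   # return_softmax (assumed)
            None,                    # RNG generator, unused with dropout 0.0
        )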
|
@ -17,7 +17,7 @@ def get_cuda_free_memory(device, memory_fraction):
|
|||||||
return free_memory
|
return free_memory
|
||||||
|
|
||||||
|
|
||||||
def get_xpu_free_memory(device):
|
def get_xpu_free_memory(device, memory_fraction):
|
||||||
total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
|
total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
|
||||||
free_memory = int(total_gpu_memory * 0.5)
|
free_memory = int(total_gpu_memory * 0.5)
|
||||||
return free_memory
|
return free_memory
|
||||||
|
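The final hunk aligns get_xpu_free_memory with the CUDA helper's signature. memory_fraction is accepted but not yet used: the XPU path still reports a flat 50% of total device memory as free. A sketch of the uniform dispatch this presumably enables (assumed, not shown in the diff):

    def free_memory_for(system, device, memory_fraction):
        # Hypothetical caller that treats both backends the same way.
        if system == "cuda":
            return get_cuda_free_memory(device, memory_fraction)
        elif system == "xpu":
            return get_xpu_free_memory(device, memory_fraction)
        raise RuntimeError(f"Unsupported system {system}")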