From 459fbdebe3bfc055bb71ad0913a498e93f30ebe1 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 15 Apr 2025 17:08:01 +0800
Subject: [PATCH] transformers flash llm/vlm enabling in ipex (#3152)

* transformers flash llm/vlm enabling in xpu

Signed-off-by: Wang, Yi A

* ipex cpu could also support in function

Signed-off-by: Wang, Yi A

---------

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel                                  | 6 ++----
 server/text_generation_server/models/__init__.py | 3 ++-
 .../models/transformers_flash_causal_lm.py       | 9 ++++++---
 .../models/transformers_flash_vlm.py             | 8 ++++++--
 server/text_generation_server/utils/dist.py      | 7 +++++++
 5 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index 5bf7632ce..b2a905ec9 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -98,9 +98,7 @@ ENV HF_HOME=/data \
 
 WORKDIR /usr/src
 
-RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
-
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 291ee5fba..93a2f8bf4 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -201,7 +201,8 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
+
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 77659dd0f..8c21ed58c 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
         if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index a7beb68b3..280fa0bd3 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 613c4784b..4a1bef6df 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -73,6 +73,13 @@ def initialize_torch_distributed():
             if SYSTEM == "ipex":
                 import intel_extension_for_pytorch as ipex
 
+                if torch.xpu.is_available():
+                    assert (
+                        WORLD_SIZE <= torch.xpu.device_count()
+                    ), "Each process is one xpu"
+                    device = RANK % torch.xpu.device_count()
+                    torch.xpu.set_device(device)
+
                 ipex.distributed.init_process_group(
                     backend="ccl",
                     world_size=WORLD_SIZE,
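
Note (illustrative, not part of the patch): TransformersFlashCausalLM and
TransformersFlashVlmCausalLM apply the same device-selection change, so a
minimal standalone sketch of that logic may help review. The select_device
helper below is hypothetical (the patch inlines this logic in each model's
__init__); it assumes "ipex" is the value SYSTEM takes from
text_generation_server.utils.import_utils on Intel systems and that rank is
the process rank.

    import os

    import torch


    def select_device(system: str, rank: int) -> torch.device:
        # Hypothetical helper mirroring the patched logic: prefer CUDA, then a
        # per-rank XPU when running under IPEX, otherwise fall back to CPU so
        # the flash transformers backend also works on CPU-only IPEX hosts.
        if torch.cuda.is_available():
            return torch.device(f"cuda:{rank}")
        if system == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return torch.device(f"xpu:{rank}")
            return torch.device("cpu")
        raise ValueError("flash transformers backend needs CUDA or IPEX")


    if __name__ == "__main__":
        # RANK defaults to 0 for a single-process run.
        print(select_device("ipex", int(os.environ.get("RANK", "0"))))

The CPU branch corresponds to the second commit message ("ipex cpu could also
support in function"): on an IPEX host without an XPU the backend now selects
CPU instead of raising.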