transformers flash llm/vlm enabling in ipex (#3152)

* transformers flash llm/vlm enabling in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ipex cpu could also be supported in this function

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-04-15 17:08:01 +08:00 committed by GitHub
parent 449cee49ca
commit 459fbdebe3
5 changed files with 23 additions and 10 deletions


@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -98,9 +98,7 @@ ENV HF_HOME=/data \
 WORKDIR /usr/src
-RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
 # Install server
 COPY proto proto

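The Dockerfile change above swaps the test-channel torch wheel for the release XPU wheels (torch 2.6.0 plus torchvision 0.21.0) and drops the separate triton-xpu install. A quick way to confirm that the installed build actually exposes the XPU backend inside the image is a small probe like the following (an illustrative sketch, not part of the diff):

# Illustrative probe for the XPU torch build installed above.
import torch

print(torch.__version__)  # expected to report an xpu build of 2.6.0
if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("xpu devices:", torch.xpu.device_count())
else:
    print("no XPU device visible; the IPEX CPU path would be used instead")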

@@ -201,7 +201,8 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,

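With the change above, FLASH_TRANSFORMERS_BACKEND is no longer tied to CUDA alone; it also turns on when SYSTEM resolves to "ipex", so the transformers flash models become eligible on Intel hardware. A minimal sketch of how such a flag gates the code path (the surrounding __init__.py logic is only partially shown in this diff, so the fallback branch here is an assumption):

# Sketch of the gating pattern; SYSTEM is hard-coded here only for illustration.
import torch

SYSTEM = "ipex"  # in TGI this comes from text_generation_server.utils.import_utils

FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"

if FLASH_TRANSFORMERS_BACKEND:
    print("transformers flash causal LM / VLM classes are eligible")
else:
    print("only the non-flash implementations are used")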

@@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(

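The device-selection branch now keys off SYSTEM == "ipex" and then distinguishes XPU from plain CPU, picking one XPU per rank instead of a single unqualified "xpu" device. The same logic, pulled out as a stand-alone helper for clarity (a sketch, assuming rank and default_dtype are supplied by the caller as in the diff):

# Stand-alone sketch of the device/dtype selection added above.
import torch

def pick_device_and_dtype(rank, system, dtype=None, default_dtype=torch.float16):
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
    elif system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")  # one XPU per rank
        else:
            device = torch.device("cpu")  # IPEX CPU-only fallback
    else:
        raise ValueError("flash transformers backend requires CUDA or IPEX")
    return device, (default_dtype if dtype is None else dtype)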

@@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(

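The VLM path gets the identical treatment, and both models import the same SYSTEM string from text_generation_server.utils.import_utils. The detection behind that string is not part of this diff; a rough, purely hypothetical sketch of the kind of check it performs:

# Hypothetical SYSTEM detector for illustration; not the actual import_utils code.
import importlib.util
import torch

def detect_system():
    if torch.cuda.is_available():
        # ROCm builds of torch report a HIP version string
        return "rocm" if getattr(torch.version, "hip", None) else "cuda"
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"  # covers both Intel XPU and IPEX on CPU
    return "cpu"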

@@ -73,6 +73,13 @@ def initialize_torch_distributed():
     if SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex

+        if torch.xpu.is_available():
+            assert (
+                WORLD_SIZE <= torch.xpu.device_count()
+            ), "Each process is one xpu"
+            device = RANK % torch.xpu.device_count()
+            torch.xpu.set_device(device)
+
         ipex.distributed.init_process_group(
             backend="ccl",
             world_size=WORLD_SIZE,
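
Finally, initialize_torch_distributed now pins each rank to its own XPU before the ccl process group is created, with no pinning needed on the CPU path. A runnable sketch of the same rank-to-device mapping, assuming WORLD_SIZE and RANK come from the launcher environment as they do in dist.py:

# Sketch of the rank-to-XPU mapping added above (environment-driven, like dist.py).
import os
import torch

WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
RANK = int(os.getenv("RANK", "0"))

if hasattr(torch, "xpu") and torch.xpu.is_available():
    assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one xpu"
    device = RANK % torch.xpu.device_count()  # rank 0 -> xpu:0, rank 1 -> xpu:1, ...
    torch.xpu.set_device(device)
else:
    device = "cpu"  # IPEX CPU path: nothing to pin before the ccl init
print("rank", RANK, "uses device", device)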