Enable Transformers flash LLM/VLM on XPU
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: 24bec29ffc
Commit: 50282e3cc1
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -98,9 +98,7 @@ ENV HF_HOME=/data \
 
 WORKDIR /usr/src
 
-RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
-
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
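The PyTorch install now comes from the stable XPU wheel index and pulls in a matching torchvision, while the separate triton-xpu pin is dropped. As a quick sanity check of such a build (a sketch, not part of the commit), something like the following can confirm the XPU backend is actually usable inside the image:

# Hedged sanity check: verifies the XPU-enabled torch build installed above.
import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    # Allocate a small tensor on the first XPU to exercise the runtime.
    x = torch.randn(2, 2, device="xpu")
    print(torch.__version__, torch.xpu.device_count(), x.device)
else:
    print("No XPU device detected")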
@@ -201,7 +201,9 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or (
+    hasattr(torch, "xpu") and torch.xpu.is_available()
+)
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
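With this change, FLASH_TRANSFORMERS_BACKEND is enabled when either a CUDA or an XPU device is present; the hasattr guard keeps the check safe on torch builds that ship without the xpu module. A minimal sketch of the same pattern in isolation (the function name is illustrative, not from the repository):

import torch

def flash_backend_available() -> bool:
    # CUDA keeps enabling the flash backend as before; a detected XPU now
    # counts as well. hasattr() avoids AttributeError on non-XPU builds.
    return torch.cuda.is_available() or (
        hasattr(torch, "xpu") and torch.xpu.is_available()
    )

print(flash_backend_available())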
@@ -116,7 +116,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+            device = torch.device(f"xpu:{rank}")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
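Binding each rank to xpu:{rank} mirrors the existing CUDA path; previously every rank resolved to the default XPU device, which defeats multi-card sharding. A standalone sketch of the device-selection logic, assuming the rank comes from the launcher environment:

import os
import torch

rank = int(os.getenv("RANK", "0"))  # assumed to be set by the launcher

if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    # Each rank now addresses its own card instead of the default "xpu" device.
    device = torch.device(f"xpu:{rank}")
else:
    raise ValueError("Neither CUDA nor XPU devices are available")

print(device)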
@@ -175,7 +175,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+            device = torch.device(f"xpu:{rank}")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
@@ -73,6 +73,13 @@ def initialize_torch_distributed():
         if SYSTEM == "ipex":
             import intel_extension_for_pytorch as ipex
 
+            if torch.xpu.is_available():
+                assert (
+                    WORLD_SIZE <= torch.xpu.device_count()
+                ), "Each process is one xpu"
+                device = RANK % torch.xpu.device_count()
+                torch.xpu.set_device(device)
+
             ipex.distributed.init_process_group(
                 backend="ccl",
                 world_size=WORLD_SIZE,
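The added block pins each process to a single XPU (rank modulo device count) before the oneCCL process group is created, and asserts that no more ranks were launched than there are cards. A self-contained sketch of just that mapping, assuming the usual RANK/WORLD_SIZE environment variables from torchrun or mpirun:

import os
import torch

RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

if hasattr(torch, "xpu") and torch.xpu.is_available():
    # One process per card: refuse to oversubscribe the available XPUs.
    assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one xpu"
    device = RANK % torch.xpu.device_count()
    torch.xpu.set_device(device)
    print(f"rank {RANK} -> xpu:{device}")

After this, ipex.distributed.init_process_group(backend="ccl", ...) proceeds as before, but with every rank already bound to its own device.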