diff --git a/Dockerfile_intel b/Dockerfile_intel
index cb0e84bb..33a7ed63 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -1,3 +1,5 @@
+ARG PLATFORM=xpu
+
 FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
 WORKDIR /usr/src

@@ -36,7 +38,8 @@ RUN cargo build --profile release-opt


 # Text Generation Inference base image for Intel
-FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
+
+FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu

 USER root
 # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@@ -88,8 +91,72 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 # Install launcher
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

-# Final image
-FROM base
+# Text Generation Inference base image for Intel-cpu
+FROM ubuntu:22.04 as cpu
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates \
+    make \
+    g++ \
+    git \
+    wget
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+
+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+    *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+RUN pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
+
+WORKDIR /usr/src
+
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
+RUN pip install intel_extension_for_pytorch-2.3.100+git0eb3473-cp310-cp310-linux_x86_64.whl
+RUN pip install oneccl_bind_pt==2.3.0 -f https://developer.intel.com/ipex-whl-stable-cpu
+
+RUN conda install -c conda-forge gperftools
+
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
+ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
+ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_intel.txt && \
+    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+FROM ${PLATFORM} as final

 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index e6cb4edf..a53c8e3b 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -1,4 +1,4 @@
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
 import os

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif SYSTEM == "xpu":
+elif IPEX_AVAIL:
     from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index 8b6cb87b..bfab0119 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -1,5 +1,6 @@
 import intel_extension_for_pytorch as ipex
 import torch
+from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@@ -56,8 +57,6 @@ def paged_attention(
     input_lengths: torch.Tensor,
     max_s: int,
 ):
-    query = query.contiguous()
-    block_size = value_cache.shape[3]
     return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@@ -67,7 +66,7 @@ def paged_attention(
         softmax_scale,
         block_tables,
         input_lengths,
-        block_size,
+        BLOCK_SIZE,
         max_s,
         None,
     )
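In the xpu.py hunk above, the paged-attention call now receives the shared BLOCK_SIZE constant from flash_causal_lm instead of deriving the block size from value_cache.shape[3]. A plausible reason is the CPU cache layout introduced later in this diff, where the last cache dimension holds head_size rather than the block size. The following standalone sketch only compares the shapes; the concrete sizes are illustrative placeholders, not values taken from the PR.

# Shape comparison only; sizes are placeholders for illustration.
import torch

num_blocks, num_heads, head_size, BLOCK_SIZE = 8, 4, 64, 16

# Existing XPU value-cache layout: the block size really is the last dimension.
xpu_value_cache = torch.empty(num_blocks, num_heads, head_size, BLOCK_SIZE)
assert xpu_value_cache.shape[3] == BLOCK_SIZE

# New IPEX CPU layout from this diff: the last dimension is head_size, so
# reading value_cache.shape[3] would silently return the wrong block size.
cpu_value_cache = torch.empty(num_blocks, num_heads, BLOCK_SIZE, head_size)
assert cpu_value_cache.shape[3] == head_size != BLOCK_SIZE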
diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py
index c4aa6c7d..93e83dfa 100644
--- a/server/text_generation_server/layers/layernorm.py
+++ b/server/text_generation_server/layers/layernorm.py
@@ -3,6 +3,7 @@ from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
     SYSTEM,
+    IPEX_AVAIL,
 )


@@ -82,18 +83,20 @@ elif SYSTEM == "rocm":

             return super().forward(hidden_states), residual


-elif SYSTEM == "xpu":
+elif IPEX_AVAIL:
     import intel_extension_for_pytorch as ipex

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            res_out = hidden_states
             out = ipex.llm.functional.add_layer_norm(
-                residual, hidden_states, self.weight, self.bias, self.eps, True
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.eps,
+                residual is not None,
             )
-            if residual is not None:
-                res_out = residual
-            return out, res_out
+            return out, residual if residual is not None else hidden_states


 class FastRMSNorm(nn.Module):
@@ -109,19 +112,16 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        if SYSTEM == "xpu":
-            residual_out = hidden_states
+        if IPEX_AVAIL:
             out = ipex.llm.functional.add_rms_norm(
                 residual,
                 hidden_states,
                 self.weight,
                 None,
                 self.variance_epsilon,
-                True,
+                residual is not None,
             )
-            if residual is not None:
-                residual_out = residual
-            return out, residual_out
+            return out, residual if residual is not None else hidden_states
         elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
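The add_layer_norm and add_rms_norm calls above now receive `residual is not None` as their last argument instead of a hard-coded True, and the callers return `residual` afterwards, which assumes the fused op adds the hidden states into it in place. The reference below is a minimal plain-PyTorch sketch of that assumed semantics for the RMS-norm case; `add_rms_norm_reference` is a hypothetical name, not the IPEX implementation.

import torch


def add_rms_norm_reference(residual, hidden_states, weight, eps, add_residual):
    # Mirrors ipex.llm.functional.add_rms_norm(residual, hidden_states, weight,
    # None, eps, residual is not None) as called in layernorm.py above.
    if add_residual:
        # The fused op is assumed to perform this addition in place on `residual`.
        hidden_states = hidden_states + residual
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    out = hidden_states * torch.rsqrt(variance + eps) * weight
    # The second return value plays the role of
    # `residual if residual is not None else hidden_states`: the pre-norm sum
    # that the caller keeps as the residual stream for the next layer.
    return out, hidden_states


x = torch.randn(2, 8)
res = torch.randn(2, 8)
w = torch.ones(8)
out_first, stream_first = add_rms_norm_reference(None, x, w, 1e-6, add_residual=False)
out_later, stream_later = add_rms_norm_reference(res, x, w, 1e-6, add_residual=True)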
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 648d28ab..43535d83 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -2,14 +2,14 @@ import os
 import torch
 from torch import nn

-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif SYSTEM == "xpu":
+elif IPEX_AVAIL:
     import intel_extension_for_pytorch as ipex


@@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):

             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif SYSTEM == "xpu":
+        elif IPEX_AVAIL:
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 94cf6452..56292250 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -20,9 +20,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import IPEX_AVAIL

-if SYSTEM != "xpu":
+if not IPEX_AVAIL:
     from vllm.model_executor.layers.fused_moe import fused_moe

 from text_generation_server.layers.attention import (
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 3900bf73..6ea95411 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np
 from torch import nn

-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import IPEX_AVAIL

-if SYSTEM != "xpu":
+if not IPEX_AVAIL:
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d16d3710..633b066b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
@@ -773,21 +773,38 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size

-        self.kv_cache = [
-            (
-                torch.empty(
-                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
-                    dtype=dtype,
-                    device=device,
-                ),
-                torch.empty(
-                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
-                    dtype=dtype,
-                    device=device,
-                ),
-            )
-            for _ in range(num_layers)
-        ]
+        if IPEX_AVAIL and SYSTEM == "cpu":
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, num_heads, head_size, BLOCK_SIZE),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]

     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
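For reference, the two KV-cache layouts selected in the hunk above can be written side by side. The sizes below are made-up example values; `x` follows the same BLOCK_SIZE // element_size packing factor computed just before the hunk.

import torch

# Example sizes only; the real values come from the model config at warmup.
num_blocks, num_layers, num_heads, head_size, BLOCK_SIZE = 8, 2, 4, 64, 16
dtype = torch.bfloat16
element_size = torch.tensor([], dtype=dtype).element_size()
x = BLOCK_SIZE // element_size

# IPEX CPU path: plain (blocks, heads, block_size, head_size) key and value blocks.
cpu_kv_cache = [
    (
        torch.empty((num_blocks, num_heads, BLOCK_SIZE, head_size), dtype=dtype),
        torch.empty((num_blocks, num_heads, BLOCK_SIZE, head_size), dtype=dtype),
    )
    for _ in range(num_layers)
]

# CUDA/XPU path (unchanged): keys keep the packed `x` layout expected by the
# paged-attention kernels, values stay (blocks, heads, head_size, block_size).
gpu_kv_cache = [
    (
        torch.empty((num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=dtype),
        torch.empty((num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=dtype),
    )
    for _ in range(num_layers)
]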
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index ef129e92..43f374e5 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 tracer = trace.get_tracer(__name__)

@@ -37,6 +37,9 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index c5cbd2b8..f3db832a 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -17,7 +17,7 @@ from text_generation_server.utils import (

 tracer = trace.get_tracer(__name__)

-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL


 class FlashLlama(FlashCausalLM):
@@ -37,6 +37,9 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 081c2e2c..fa084df5 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -16,7 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 tracer = trace.get_tracer(__name__)

@@ -41,6 +41,9 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index adefaeb2..a6327e0e 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 tracer = trace.get_tracer(__name__)

@@ -36,6 +36,9 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index e6350611..ff4ff147 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 tracer = trace.get_tracer(__name__)

@@ -37,6 +37,9 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 2ad36b93..7b7add84 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )

-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

 tracer = trace.get_tracer(__name__)

@@ -40,6 +40,9 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IPEX_AVAIL:
+            device = torch.device("cpu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
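Each model class above gains the same branch: CUDA and XPU keep float16 defaults, while an IPEX-only environment falls back to CPU with bfloat16. A condensed sketch of that shared pattern follows; `select_device_and_dtype` is a hypothetical helper, not part of this diff, and `system` / `ipex_available` stand in for import_utils.SYSTEM and IPEX_AVAIL.

import torch


def select_device_and_dtype(system, ipex_available, rank, dtype=None):
    # Condensed version of the branch repeated in each Flash* model constructor.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}"), torch.float16 if dtype is None else dtype
    if system == "xpu":
        return torch.device(f"xpu:{rank}"), torch.float16 if dtype is None else dtype
    if ipex_available:
        # New in this diff: IPEX without an XPU device means CPU with bfloat16.
        return torch.device("cpu"), torch.bfloat16 if dtype is None else dtype
    raise NotImplementedError("Flash models need CUDA, an XPU, or IPEX on CPU")


device, dtype = select_device_and_dtype(system=None, ipex_available=True, rank=0)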
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index d79e36c2..f1fd8f79 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,14 +1,14 @@
 import torch
 from loguru import logger
+from text_generation_server.utils.dist import WORLD_SIZE


-def is_xpu_available():
+def is_ipex_available():
     try:
         import intel_extension_for_pytorch
     except ImportError:
         return False
-
-    return hasattr(torch, "xpu") and torch.xpu.is_available()
+    return True


 def get_cuda_free_memory(device, memory_fraction):
@@ -24,6 +24,15 @@ def get_xpu_free_memory(device, memory_fraction):
     return free_memory


+def get_cpu_free_memory(device, memory_fraction):
+    import psutil
+
+    mem = psutil.virtual_memory()
+    free_memory = int(mem.available * 0.95 / WORLD_SIZE)
+    return free_memory
+
+
+IPEX_AVAIL = is_ipex_available()
 SYSTEM = None
 if torch.version.hip is not None:
     SYSTEM = "rocm"
@@ -35,7 +44,7 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
     empty_cache = torch.cuda.empty_cache
     synchronize = torch.cuda.synchronize
     get_free_memory = get_cuda_free_memory
-elif is_xpu_available():
+elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available():
     SYSTEM = "xpu"
     empty_cache = torch.xpu.empty_cache
     synchronize = torch.xpu.synchronize
@@ -48,5 +57,5 @@ else:
     empty_cache = noop
     synchronize = noop
-    get_free_memory = noop
+    get_free_memory = get_cpu_free_memory

 logger.info(f"Detected system {SYSTEM}")
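With the import_utils.py changes above, a host that can import intel_extension_for_pytorch but has no XPU device keeps the CPU code path and now gets a real memory budget from psutil instead of a no-op. The sketch below restates that accounting on its own; `world_size` stands in for text_generation_server.utils.dist.WORLD_SIZE, and note that the fraction is fixed at 0.95 in the hunk rather than taken from the memory_fraction argument.

import psutil


def cpu_free_memory(world_size: int) -> int:
    # Same formula as get_cpu_free_memory above: roughly 95% of the host RAM
    # that is currently available, split evenly across the tensor-parallel ranks.
    mem = psutil.virtual_memory()
    return int(mem.available * 0.95 / world_size)


if __name__ == "__main__":
    budget = cpu_free_memory(world_size=1)
    print(f"per-rank CPU memory budget: {budget / 2**30:.1f} GiB")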