From 49cd0ce94379950809d30c9ecb69c7b2e1294584 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 5 Mar 2024 17:52:53 -0800 Subject: [PATCH] add intel xpu support for TGI Signed-off-by: Wang, Yi A --- Dockerfile_intel | 72 ++++++++++ .../models/cache_manager.py | 6 +- .../models/flash_causal_lm.py | 28 ++-- .../models/flash_llama.py | 4 + .../models/flash_mistral.py | 4 + .../models/flash_neox.py | 5 +- .../text_generation_server/models/flash_rw.py | 5 +- .../models/flash_santacoder.py | 4 + server/text_generation_server/utils/dist.py | 9 +- .../utils/flash_attn.py | 132 +++++++++++------- .../utils/import_utils.py | 9 ++ server/text_generation_server/utils/layers.py | 25 +++- .../utils/paged_attention.py | 23 ++- 13 files changed, 254 insertions(+), 72 deletions(-) create mode 100644 Dockerfile_intel diff --git a/Dockerfile_intel b/Dockerfile_intel new file mode 100644 index 00000000..16b059ca --- /dev/null +++ b/Dockerfile_intel @@ -0,0 +1,72 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Text Generation Inference base image for Intel +FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base + +USER root +# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it +RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ + dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_common.txt && \ + pip install ".[accelerate, peft]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# Final image +FROM base + +ENTRYPOINT ["text-generation-launcher"] +CMD ["--json-output"] diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 85e1b19b..4c65e2dd 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -2,6 +2,7 @@ import math import torch from typing import Optional, List, Tuple +from 
text_generation_server.utils.import_utils import IS_XPU_SYSTEM BLOCK_SIZE: int = 16 # Will be set in warmup @@ -24,7 +25,10 @@ class CacheManager: self.repeat_slots = repeat_slots element_size = torch.tensor([], dtype=dtype).element_size() - x = self.block_size // element_size + if IS_XPU_SYSTEM: + x = 1 + else: + x = self.block_size // element_size self.kv_cache = [ ( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1189ccdd..94518b8f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -33,7 +33,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) - +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM @dataclass class FlashCausalLMBatch(Batch): @@ -752,7 +752,10 @@ class FlashCausalLM(Model): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive - torch.cuda.empty_cache() + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + torch.cuda.empty_cache() + elif IS_XPU_SYSTEM: + torch.xpu.empty_cache() try: cache_manager = set_cache_manager( batch.blocks, @@ -772,7 +775,10 @@ class FlashCausalLM(Model): f"You need to decrease `--max-batch-prefill-tokens`" ) from e - torch.cuda.synchronize(self.device) + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + torch.cuda.synchronize(self.device) + elif IS_XPU_SYSTEM: + torch.xpu.synchronize(self.device) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -780,12 +786,18 @@ class FlashCausalLM(Model): cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - total_free_memory, _ = torch.cuda.mem_get_info(self.device) - total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + total_free_memory, _ = torch.cuda.mem_get_info(self.device) + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory - free_memory = max( - 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory - ) + free_memory = max( + 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory + ) + elif IS_XPU_SYSTEM: + total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory + free_memory = int(total_gpu_memory *0.5) + else: + raise NotImplementedError("FlashModel is only available on GPU") num_blocks = ( # Leave 5% for some wiggle room diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index f3578f88..f37fc542 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -18,6 +18,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM class FlashLlama(FlashCausalLM): def __init__( @@ -33,6 +34,9 @@ class FlashLlama(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise 
NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 52a30b5f..47121abe 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -33,6 +33,7 @@ tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM MEM_POOL = torch.cuda.graph_pool_handle() @@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 5a351bd7..70c978de 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) - +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +32,9 @@ class FlashNeoXSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index fc1e26bd..6eb25f22 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) - +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -33,6 +33,9 @@ class FlashRWSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 034949f9..6147398a 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,6 +18,7 @@ from text_generation_server.utils import ( Weights, ) +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -35,6 +36,9 @@ class FlashSantacoderSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index d02bfc5b..3625e6f2 
100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -57,7 +57,14 @@ def initialize_torch_distributed(): options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) else: - backend = "gloo" + try: + import oneccl_bindings_for_pytorch + + backend = "ccl" + if os.getenv("CCL_WORKER_COUNT", None) is None: + os.environ["CCL_WORKER_COUNT"] = str(1) + except ImportError: + backend = "gloo" options = None if WORLD_SIZE == 1: diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 45090c64..df192cc1 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -2,69 +2,74 @@ import os import torch from loguru import logger +import math -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") - -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 - -HAS_FLASH_ATTN = False +HAS_FLASH_ATTN = True HAS_FLASH_ATTN_V2_CUDA = False HAS_FLASH_ATTN_V2_ROCM = False -try: + +if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + + major, minor = torch.cuda.get_device_capability() + is_sm75 = major == 7 and minor == 5 + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + + HAS_FLASH_ATTN = False + HAS_FLASH_ATTN_V2_CUDA = False + HAS_FLASH_ATTN_V2_ROCM = False try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = "" - if IS_CUDA_SYSTEM: - architecture_suffix = "-cuda" + try: + import flash_attn_2_cuda + except ImportError: + architecture_suffix = "" + if IS_CUDA_SYSTEM: + architecture_suffix = "-cuda" + elif IS_ROCM_SYSTEM: + architecture_suffix = "-rocm" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM + HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM + except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e elif IS_ROCM_SYSTEM: - architecture_suffix = "-rocm" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if not (is_sm8x or 
is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM - HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM -except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) - if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif IS_ROCM_SYSTEM: - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True def attention( @@ -80,6 +85,25 @@ def attention( if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") + if IS_XPU_SYSTEM: + return torch.xpu.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None + ) + + if HAS_FLASH_ATTN_V2_CUDA: return flash_attn_2_cuda.varlen_fwd( q, diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 428c9f3e..7c0d8001 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,4 +1,13 @@ import torch +def is_xpu_available(): + try: + import intel_extension_for_pytorch + except ImportError: + return False + + return hasattr(torch, "xpu") and torch.xpu.is_available() + IS_ROCM_SYSTEM = torch.version.hip is not None IS_CUDA_SYSTEM = torch.version.cuda is not None +IS_XPU_SYSTEM = is_xpu_available() diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 69bd5e88..04ae7724 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -18,7 +18,7 @@ except ImportError: from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM HAS_AWQ = True try: @@ -812,7 +812,13 @@ try: class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + if IS_XPU_SYSTEM: + if residual is not None: + hidden_states += residual + residual = hidden_states + out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps) + return out, residual + elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if residual is not None: hidden_states += residual residual = 
hidden_states @@ -858,7 +864,15 @@ try: return cls(weight, eps) def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: + if IS_XPU_SYSTEM: + if residual is not None: + hidden_states += residual + residual = hidden_states + out = torch.ops.torch_ipex.rms_norm( + hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon + ) + return out[0], residual + elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states @@ -984,11 +998,16 @@ try: # Inplace operation, updating query and key. pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) + elif IS_XPU_SYSTEM: + sin = sin.repeat(1, 1, 2).expand(query.shape) + cos = cos.repeat(1, 1, 2).expand(query.shape) + torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." ) + @classmethod def static(cls, config, dim, base, device): inv_freq = _create_inv_freq(dim, base, device) diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 487a3a72..0130cd0c 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,10 +1,10 @@ import torch - -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM _PARTITION_SIZE = 512 + def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, @@ -22,6 +22,8 @@ def reshape_and_cache( from vllm import cache_ops cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) + elif IS_XPU_SYSTEM: + torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots) else: raise ValueError("vllm is not supported on your system") @@ -63,7 +65,22 @@ def attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if IS_XPU_SYSTEM: + query = query.contiguous() + return torch.xpu.IpexPaged_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + block_tables, + input_lengths, + softmax_scale, + block_size, + max_s, + None + ) + if use_v1: if IS_CUDA_SYSTEM: from vllm._C import ops
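
Notes on the new XPU code paths (illustrative Python sketches; none of these are hunks from the patch).

All of the new branches key off IS_XPU_SYSTEM, which import_utils.py now derives by probing for intel_extension_for_pytorch and torch.xpu. A minimal standalone sketch of that detection, mirroring the is_xpu_available helper added above:

    import torch

    def is_xpu_available() -> bool:
        # torch.xpu is only registered once intel_extension_for_pytorch has been
        # imported, so probe the import first and fall back to False.
        try:
            import intel_extension_for_pytorch  # noqa: F401
        except ImportError:
            return False
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    IS_ROCM_SYSTEM = torch.version.hip is not None
    IS_CUDA_SYSTEM = torch.version.cuda is not None
    IS_XPU_SYSTEM = is_xpu_available()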
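Each flash_* model constructor repeats the same three-way device selection: CUDA/ROCm first, then XPU, otherwise fail. A hypothetical helper (select_flash_device is not a function in the patch) that condenses the shared pattern:

    from typing import Optional

    import torch

    from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

    def select_flash_device(rank: int, dtype: Optional[torch.dtype] = None):
        # Hypothetical helper: each flash_* constructor in the patch inlines
        # this logic rather than calling a shared function.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
        elif IS_XPU_SYSTEM:
            device = torch.device(f"xpu:{rank}")
        else:
            raise NotImplementedError("Flash models are only available on GPU")
        return device, torch.float16 if dtype is None else dtype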
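The cache_manager.py hunk changes the key-cache packing factor: CUDA/ROCm keep x = block_size // element_size, while the XPU kernels used here appear to expect an unpacked layout, so the patch sets x = 1. Condensed (dtype, self.block_size and IS_XPU_SYSTEM are in scope as in that file):

    # Inside CacheManager.__init__, before allocating self.kv_cache.
    element_size = torch.tensor([], dtype=dtype).element_size()
    if IS_XPU_SYSTEM:
        # XPU path: no extra packing of the key cache.
        x = 1
    else:
        # CUDA/ROCm path: keep the original packing factor.
        x = self.block_size // element_size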
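KV-cache sizing in warmup() also diverges per backend: CUDA/ROCm keep using torch.cuda.mem_get_info together with MEMORY_FRACTION, while the XPU path has no equivalent query in this patch and simply budgets half of the device's total memory. Condensed from the flash_causal_lm.py hunk, assuming self.device, MEMORY_FRACTION and the IS_* flags are in scope as in that file:

    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )
    elif IS_XPU_SYSTEM:
        # No free-memory query is used on XPU here; budget 50% of total memory.
        total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
        free_memory = int(total_gpu_memory * 0.5)
    else:
        raise NotImplementedError("FlashModel is only available on GPU")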
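For sharded runs, dist.py now prefers the oneCCL backend over gloo whenever oneccl_bindings_for_pytorch can be imported, and defaults CCL_WORKER_COUNT to 1 if the user has not set it. A sketch of just that selection step (pick_shard_backend is a hypothetical name; the patch inlines this in initialize_torch_distributed()):

    import os

    def pick_shard_backend() -> str:
        try:
            import oneccl_bindings_for_pytorch  # noqa: F401
        except ImportError:
            # Without the oneCCL bindings, keep the previous gloo fallback.
            return "gloo"
        # oneCCL needs a worker count; respect any value the user already set.
        os.environ.setdefault("CCL_WORKER_COUNT", "1")
        return "ccl"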