add intel xpu support for TGI

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-03-05 17:52:53 -08:00 committed by Nicolas Patry
parent ee47973a2f
commit 49cd0ce943
13 changed files with 254 additions and 72 deletions

72
Dockerfile_intel Normal file
View File

@ -0,0 +1,72 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
ARG GIT_SHA
ARG DOCKER_LABEL
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release
# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base
USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_common.txt && \
pip install ".[accelerate, peft]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image
FROM base
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -2,6 +2,7 @@ import math
import torch import torch
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
BLOCK_SIZE: int = 16 BLOCK_SIZE: int = 16
# Will be set in warmup # Will be set in warmup
@ -24,6 +25,9 @@ class CacheManager:
self.repeat_slots = repeat_slots self.repeat_slots = repeat_slots
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
if IS_XPU_SYSTEM:
x = 1
else:
x = self.block_size // element_size x = self.block_size // element_size
self.kv_cache = [ self.kv_cache = [

View File

@ -33,7 +33,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
@dataclass @dataclass
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
@ -752,7 +752,10 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif IS_XPU_SYSTEM:
torch.xpu.empty_cache()
try: try:
cache_manager = set_cache_manager( cache_manager = set_cache_manager(
batch.blocks, batch.blocks,
@ -772,7 +775,10 @@ class FlashCausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`" f"You need to decrease `--max-batch-prefill-tokens`"
) from e ) from e
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.synchronize(self.device) torch.cuda.synchronize(self.device)
elif IS_XPU_SYSTEM:
torch.xpu.synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory # Calculate the number of blocks that can be allocated with the free memory
@ -780,12 +786,18 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
total_free_memory, _ = torch.cuda.mem_get_info(self.device) total_free_memory, _ = torch.cuda.mem_get_info(self.device)
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
free_memory = max( free_memory = max(
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
) )
elif IS_XPU_SYSTEM:
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
free_memory = int(total_gpu_memory *0.5)
else:
raise NotImplementedError("FlashModel is only available on GPU")
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room

View File

@ -18,6 +18,7 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
class FlashLlama(FlashCausalLM): class FlashLlama(FlashCausalLM):
def __init__( def __init__(
@ -33,6 +34,9 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")

View File

@ -33,6 +33,7 @@ tracer = trace.get_tracer(__name__)
# Will be set in init # Will be set in init
SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
MEM_POOL = torch.cuda.graph_pool_handle() MEM_POOL = torch.cuda.graph_pool_handle()
@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashMistral is only available on GPU") raise NotImplementedError("FlashMistral is only available on GPU")

View File

@ -14,7 +14,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -32,6 +32,9 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")

View File

@ -15,7 +15,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -33,6 +33,9 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashRW is only available on GPU") raise NotImplementedError("FlashRW is only available on GPU")

View File

@ -18,6 +18,7 @@ from text_generation_server.utils import (
Weights, Weights,
) )
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -35,6 +36,9 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

View File

@ -57,6 +57,13 @@ def initialize_torch_distributed():
options.is_high_priority_stream = True options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60) options._timeout = timedelta(seconds=60)
else: else:
try:
import oneccl_bindings_for_pytorch
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo" backend = "gloo"
options = None options = None

View File

@ -2,24 +2,29 @@ import os
import torch import torch
from loguru import logger from loguru import logger
import math
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.") raise ImportError("`USE_FLASH_ATTENTION` is false.")
HAS_FLASH_ATTN = True
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False HAS_FLASH_ATTN_V2_ROCM = False
try:
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
try: try:
import flash_attn_2_cuda import flash_attn_2_cuda
except ImportError: except ImportError:
@ -40,7 +45,7 @@ try:
) )
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e: except ImportError as e:
try: try:
import flash_attn_cuda import flash_attn_cuda
except ImportError: except ImportError:
@ -80,6 +85,25 @@ def attention(
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
if IS_XPU_SYSTEM:
return torch.xpu.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None
)
if HAS_FLASH_ATTN_V2_CUDA: if HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,

View File

@ -1,4 +1,13 @@
import torch import torch
def is_xpu_available():
try:
import intel_extension_for_pytorch
except ImportError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()
IS_ROCM_SYSTEM = torch.version.hip is not None IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()

View File

@ -18,7 +18,7 @@ except ImportError:
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
HAS_AWQ = True HAS_AWQ = True
try: try:
@ -812,7 +812,13 @@ try:
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if IS_XPU_SYSTEM:
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
return out, residual
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
residual = hidden_states residual = hidden_states
@ -858,7 +864,15 @@ try:
return cls(weight, eps) return cls(weight, eps)
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192: if IS_XPU_SYSTEM:
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.ops.torch_ipex.rms_norm(
hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
)
return out[0], residual
elif hidden_states.shape[-1] > 8192:
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
residual = hidden_states residual = hidden_states
@ -984,11 +998,16 @@ try:
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IS_XPU_SYSTEM:
sin = sin.repeat(1, 1, 2).expand(query.shape)
cos = cos.repeat(1, 1, 2).expand(query.shape)
torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
) )
@classmethod @classmethod
def static(cls, config, dim, base, device): def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)

View File

@ -1,10 +1,10 @@
import torch import torch
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
@ -22,6 +22,8 @@ def reshape_and_cache(
from vllm import cache_ops from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
elif IS_XPU_SYSTEM:
torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
else: else:
raise ValueError("vllm is not supported on your system") raise ValueError("vllm is not supported on your system")
@ -63,7 +65,22 @@ def attention(
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if IS_XPU_SYSTEM:
query = query.contiguous()
return torch.xpu.IpexPaged_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
block_tables,
input_lengths,
softmax_scale,
block_size,
max_s,
None
)
if use_v1: if use_v1:
if IS_CUDA_SYSTEM: if IS_CUDA_SYSTEM:
from vllm._C import ops from vllm._C import ops