mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
add intel xpu support for TGI
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ee47973a2f
commit
49cd0ce943
72
Dockerfile_intel
Normal file
72
Dockerfile_intel
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# Rust builder
|
||||||
|
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||||
|
|
||||||
|
FROM chef as planner
|
||||||
|
COPY Cargo.toml Cargo.toml
|
||||||
|
COPY rust-toolchain.toml rust-toolchain.toml
|
||||||
|
COPY proto proto
|
||||||
|
COPY benchmark benchmark
|
||||||
|
COPY router router
|
||||||
|
COPY launcher launcher
|
||||||
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
|
FROM chef AS builder
|
||||||
|
|
||||||
|
ARG GIT_SHA
|
||||||
|
ARG DOCKER_LABEL
|
||||||
|
|
||||||
|
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||||
|
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||||
|
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||||
|
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
|
||||||
|
rm -f $PROTOC_ZIP
|
||||||
|
|
||||||
|
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||||
|
RUN cargo chef cook --release --recipe-path recipe.json
|
||||||
|
|
||||||
|
COPY Cargo.toml Cargo.toml
|
||||||
|
COPY rust-toolchain.toml rust-toolchain.toml
|
||||||
|
COPY proto proto
|
||||||
|
COPY benchmark benchmark
|
||||||
|
COPY router router
|
||||||
|
COPY launcher launcher
|
||||||
|
RUN cargo build --release
|
||||||
|
|
||||||
|
# Text Generation Inference base image for Intel
|
||||||
|
FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base
|
||||||
|
|
||||||
|
USER root
|
||||||
|
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||||
|
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
|
||||||
|
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
|
||||||
|
|
||||||
|
# Text Generation Inference base env
|
||||||
|
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
|
PORT=80
|
||||||
|
|
||||||
|
|
||||||
|
# Install server
|
||||||
|
COPY proto proto
|
||||||
|
COPY server server
|
||||||
|
COPY server/Makefile server/Makefile
|
||||||
|
RUN cd server && \
|
||||||
|
make gen-server && \
|
||||||
|
pip install -r requirements_common.txt && \
|
||||||
|
pip install ".[accelerate, peft]" --no-cache-dir
|
||||||
|
|
||||||
|
# Install benchmarker
|
||||||
|
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
|
# Install router
|
||||||
|
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||||
|
# Install launcher
|
||||||
|
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
|
# Final image
|
||||||
|
FROM base
|
||||||
|
|
||||||
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
|
CMD ["--json-output"]
|
@ -2,6 +2,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
|
|
||||||
BLOCK_SIZE: int = 16
|
BLOCK_SIZE: int = 16
|
||||||
# Will be set in warmup
|
# Will be set in warmup
|
||||||
@ -24,6 +25,9 @@ class CacheManager:
|
|||||||
self.repeat_slots = repeat_slots
|
self.repeat_slots = repeat_slots
|
||||||
|
|
||||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
if IS_XPU_SYSTEM:
|
||||||
|
x = 1
|
||||||
|
else:
|
||||||
x = self.block_size // element_size
|
x = self.block_size // element_size
|
||||||
|
|
||||||
self.kv_cache = [
|
self.kv_cache = [
|
||||||
|
@ -33,7 +33,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
|
|||||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashCausalLMBatch(Batch):
|
class FlashCausalLMBatch(Batch):
|
||||||
@ -752,7 +752,10 @@ class FlashCausalLM(Model):
|
|||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(self, batch: FlashCausalLMBatch):
|
||||||
# The warmup batch is the biggest batch we could ever receive
|
# The warmup batch is the biggest batch we could ever receive
|
||||||
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
torch.xpu.empty_cache()
|
||||||
try:
|
try:
|
||||||
cache_manager = set_cache_manager(
|
cache_manager = set_cache_manager(
|
||||||
batch.blocks,
|
batch.blocks,
|
||||||
@ -772,7 +775,10 @@ class FlashCausalLM(Model):
|
|||||||
f"You need to decrease `--max-batch-prefill-tokens`"
|
f"You need to decrease `--max-batch-prefill-tokens`"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
torch.cuda.synchronize(self.device)
|
torch.cuda.synchronize(self.device)
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
torch.xpu.synchronize(self.device)
|
||||||
|
|
||||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||||
# Calculate the number of blocks that can be allocated with the free memory
|
# Calculate the number of blocks that can be allocated with the free memory
|
||||||
@ -780,12 +786,18 @@ class FlashCausalLM(Model):
|
|||||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||||
|
|
||||||
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
|
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
|
||||||
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||||
|
|
||||||
free_memory = max(
|
free_memory = max(
|
||||||
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
|
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
|
||||||
)
|
)
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
|
||||||
|
free_memory = int(total_gpu_memory *0.5)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("FlashModel is only available on GPU")
|
||||||
|
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
# Leave 5% for some wiggle room
|
# Leave 5% for some wiggle room
|
||||||
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
|||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
|
|
||||||
class FlashLlama(FlashCausalLM):
|
class FlashLlama(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -33,6 +34,9 @@ class FlashLlama(FlashCausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ tracer = trace.get_tracer(__name__)
|
|||||||
# Will be set in init
|
# Will be set in init
|
||||||
SLIDING_WINDOW: Optional[int] = None
|
SLIDING_WINDOW: Optional[int] = None
|
||||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle()
|
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||||
|
|
||||||
@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from text_generation_server.utils import (
|
|||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -32,6 +32,9 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from text_generation_server.utils import (
|
|||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -33,6 +33,9 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
raise NotImplementedError("FlashRW is only available on GPU")
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
|||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +36,9 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||||
|
|
||||||
|
@ -57,6 +57,13 @@ def initialize_torch_distributed():
|
|||||||
options.is_high_priority_stream = True
|
options.is_high_priority_stream = True
|
||||||
options._timeout = timedelta(seconds=60)
|
options._timeout = timedelta(seconds=60)
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
import oneccl_bindings_for_pytorch
|
||||||
|
|
||||||
|
backend = "ccl"
|
||||||
|
if os.getenv("CCL_WORKER_COUNT", None) is None:
|
||||||
|
os.environ["CCL_WORKER_COUNT"] = str(1)
|
||||||
|
except ImportError:
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
options = None
|
options = None
|
||||||
|
|
||||||
|
@ -2,12 +2,17 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import math
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||||
|
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
|
HAS_FLASH_ATTN = True
|
||||||
|
HAS_FLASH_ATTN_V2_CUDA = False
|
||||||
|
HAS_FLASH_ATTN_V2_ROCM = False
|
||||||
|
|
||||||
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
raise ImportError("CUDA is not available")
|
raise ImportError("CUDA is not available")
|
||||||
|
|
||||||
@ -80,6 +85,25 @@ def attention(
|
|||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
|
if IS_XPU_SYSTEM:
|
||||||
|
return torch.xpu.varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if HAS_FLASH_ATTN_V2_CUDA:
|
if HAS_FLASH_ATTN_V2_CUDA:
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
q,
|
||||||
|
@ -1,4 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
def is_xpu_available():
|
||||||
|
try:
|
||||||
|
import intel_extension_for_pytorch
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||||
|
|
||||||
IS_ROCM_SYSTEM = torch.version.hip is not None
|
IS_ROCM_SYSTEM = torch.version.hip is not None
|
||||||
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
||||||
|
IS_XPU_SYSTEM = is_xpu_available()
|
||||||
|
@ -18,7 +18,7 @@ except ImportError:
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||||
|
|
||||||
HAS_AWQ = True
|
HAS_AWQ = True
|
||||||
try:
|
try:
|
||||||
@ -812,7 +812,13 @@ try:
|
|||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
if IS_XPU_SYSTEM:
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return out, residual
|
||||||
|
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -858,7 +864,15 @@ try:
|
|||||||
return cls(weight, eps)
|
return cls(weight, eps)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if hidden_states.shape[-1] > 8192:
|
if IS_XPU_SYSTEM:
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
out = torch.ops.torch_ipex.rms_norm(
|
||||||
|
hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
|
||||||
|
)
|
||||||
|
return out[0], residual
|
||||||
|
elif hidden_states.shape[-1] > 8192:
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -984,11 +998,16 @@ try:
|
|||||||
|
|
||||||
# Inplace operation, updating query and key.
|
# Inplace operation, updating query and key.
|
||||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
sin = sin.repeat(1, 1, 2).expand(query.shape)
|
||||||
|
cos = cos.repeat(1, 1, 2).expand(query.shape)
|
||||||
|
torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def static(cls, config, dim, base, device):
|
def static(cls, config, dim, base, device):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
|
||||||
|
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
def reshape_and_cache(
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
@ -22,6 +22,8 @@ def reshape_and_cache(
|
|||||||
from vllm import cache_ops
|
from vllm import cache_ops
|
||||||
|
|
||||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
|
elif IS_XPU_SYSTEM:
|
||||||
|
torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
else:
|
else:
|
||||||
raise ValueError("vllm is not supported on your system")
|
raise ValueError("vllm is not supported on your system")
|
||||||
|
|
||||||
@ -63,7 +65,22 @@ def attention(
|
|||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
# sequences or heads is large, we use V1 since there is enough work
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
# to parallelize.
|
# to parallelize.
|
||||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
if IS_XPU_SYSTEM:
|
||||||
|
query = query.contiguous()
|
||||||
|
return torch.xpu.IpexPaged_attention(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
softmax_scale,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
if use_v1:
|
if use_v1:
|
||||||
if IS_CUDA_SYSTEM:
|
if IS_CUDA_SYSTEM:
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
Loading…
Reference in New Issue
Block a user