Merge branch 'huggingface:main' into qwen3_moe

commit 45d95bdccc
Yuan Wu, 2025-05-23 10:26:57 +08:00, committed by GitHub
83 changed files with 1739 additions and 802 deletions


@@ -21,7 +21,7 @@ jobs:
 nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
 with:
-name: text-generation-inference
+name: huggingface
 # If you chose signing key for write access
 authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
 env:


@@ -20,7 +20,7 @@ jobs:
 nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
 with:
-name: text-generation-inference
+name: huggingface
 # If you chose signing key for write access
 authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
 env:


@@ -25,7 +25,7 @@ jobs:
 nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
 with:
-name: text-generation-inference
+name: huggingface
 # If you chose signing key for write access
 authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
 env:

Cargo.lock (generated), 16 changes

@@ -4650,7 +4650,7 @@ dependencies = [
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "async-trait",
 "clap 4.5.32",
@@ -4671,7 +4671,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "average",
 "clap 4.5.32",
@@ -4691,7 +4691,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "async-trait",
 "base64 0.22.1",
@@ -4709,7 +4709,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "clap 4.5.32",
 "ctrlc",
@@ -4730,7 +4730,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "anyhow",
 "async-stream",
@@ -4782,7 +4782,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router-llamacpp"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "async-trait",
 "bindgen 0.71.1",
@@ -4800,7 +4800,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router-v2"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "async-stream",
 "async-trait",
@@ -4849,7 +4849,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router-v3"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
 "async-stream",
 "async-trait",


@@ -21,7 +21,7 @@ default-members = [
 resolver = "2"
 [workspace.package]
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"


@@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
 # Build specific version of transformers
 RUN . .venv/bin/activate && make build-awq
-# Build Lorax Punica kernels
-FROM kernel-builder AS lorax-punica-builder
-WORKDIR /usr/src
-COPY server/Makefile-lorax-punica Makefile
-# Build specific version of transformers
-RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
 # Build Transformers CUDA kernels
 FROM kernel-builder AS custom-kernels-builder
 WORKDIR /usr/src
@@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages


@@ -6,7 +6,7 @@
 FROM nixos/nix:2.18.8 AS builder
 RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
 RUN nix profile install nixpkgs#cachix
-RUN cachix use text-generation-inference
+RUN cachix use huggingface
 WORKDIR /root
 ADD . .
 RUN nix build .


@@ -1,5 +1,5 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.20.0
+ARG HABANA_VERSION=1.21.0
 ARG PYTORCH_VERSION=2.6.0
 # Rust builder
@@ -60,6 +60,9 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
 ENV ATTENTION=default
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
+ENV PT_HPU_LAZY_MODE=1
+ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -95,7 +98,8 @@ RUN cd server && \
 pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
 BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
 pip install . --no-cache-dir
-RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router


@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
+ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
 ```
 And then you can make requests like
@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above.
 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 token=<your cli READ token>
 docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
+ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
 ```
 ### A note on Shared Memory (shm)
@@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt
 we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
 be pulled from a binary cache, removing the need to build them locally.
-First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
+First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
 Setting up the cache is important, otherwise Nix will build many of the dependencies
 locally, which can take hours.


@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := ${mkfile_dir}/../..
-HABANA_VERSION := 1.20.0
+HABANA_VERSION := 1.21.0
 PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install


@@ -26,6 +26,11 @@ class Dtype(str, Enum):
 bloat16 = "bfloat16"
+class KVCacheDtype(str, Enum):
+fp8_e4m3fn = "fp8_e4m3fn"
+fp8_e5m2 = "fp8_e5m2"
 @app.command()
 def serve(
 model_id: str,
@@ -34,6 +39,7 @@ def serve(
 quantize: Optional[Quantization] = None,
 speculate: Optional[int] = None,
 dtype: Optional[Dtype] = None,
+kv_cache_dtype: Optional[KVCacheDtype] = None,
 trust_remote_code: bool = False,
 uds_path: Path = "/tmp/text-generation-server",
 logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
 # Downgrade enum into str for easier management later on
 quantize = None if quantize is None else quantize.value
 dtype = "bfloat16" if dtype is None else dtype.value
-logger.info(f"quantize={quantize}")
+kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
 if dtype is not None and quantize not in {
 None,
 "bitsandbytes",
@@ -175,6 +182,7 @@ def serve(
 quantize,
 speculate,
 dtype,
+kv_cache_dtype,
 trust_remote_code,
 uds_path,
 max_input_tokens,
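The new `--kv-cache-dtype` flag mirrors the existing `--dtype` handling: the Typer enum is downgraded to a plain string before being forwarded to `get_model`. A minimal standalone sketch of that normalization step, assuming only the enum shown above (the helper name below is made up for illustration):

```
from enum import Enum
from typing import Optional


class KVCacheDtype(str, Enum):
    # Same two FP8 variants as the enum added in cli.py.
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


def normalize_kv_cache_dtype(value: Optional[KVCacheDtype]) -> Optional[str]:
    # Downgrade the enum into a str, exactly as serve() does for quantize/dtype.
    return None if value is None else value.value


print(normalize_kv_cache_dtype(KVCacheDtype.fp8_e4m3fn))  # fp8_e4m3fn
print(normalize_kv_cache_dtype(None))                     # None
```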


@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+from text_generation_server.layers.fp8 import Fp8Linear
 from text_generation_server.layers.lora import (
 LoraLinear,
@@ -27,6 +28,7 @@ __all__ = [
 "TensorParallelEmbedding",
 "SpeculativeHead",
 "LoraLinear",
+"Fp8Linear",
 "TensorParallelMultiAdapterLinear",
 "TensorParallelAdapterRowLinear",
 "load_layer_norm",


@@ -10,18 +10,21 @@ from .hpu import (
 SUPPORTS_WINDOWING,
 attention,
 paged_attention,
+paged_attention_mla,
 )
 # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache
 __all__ = [
 "attention",
 "get_kv_scales",
 "paged_attention",
+"paged_attention_mla",
 "SUPPORTS_WINDOWING",
 "KVCache",
+"KVCompressCache",
 "Seqlen",
 "HPUPagedAttentionMetadata",
 "trim_seqlen_metadata",


@@ -90,6 +90,8 @@ class Seqlen:
 def _async_h2d_tensor_copy(source, device="hpu"):
 if source is None:
 return None
+if source.device.type == "hpu":
+return source
 assert source.device.type == "cpu", "Source tensor is not present in host memory!"
 target = torch.empty(source.shape, dtype=source.dtype, device=device)
 target.copy_(source, non_blocking=True)


@@ -7,15 +7,66 @@ from vllm_hpu_extension.utils import Matmul
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE
 SUPPORTS_WINDOWING = False
-def fetch_from_cache(cache, blocks):
-if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-return cache[: blocks.size(0)]
-else:
-return cache.index_select(0, blocks)
+class FP8Matmul(torch.nn.Module):
+def __init__(self, scale_other):
+super().__init__()
+self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
self.scale_other = scale_other
def quant_input(self, x, scale):
return torch.ops.hpu.cast_to_fp8_v2(
x, scale, False, False, torch.float8_e4m3fn
)[0]
def matmul_fp8(
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
):
return torch.ops.hpu.fp8_gemm_v2(
A=x,
trans_A=False,
B=other,
trans_B=False,
D=None,
out_dtype=out_dtype,
A_scale_inv=scale_input_inv,
B_scale_inv=scale_other_inv,
bias=None,
accumulate=False,
)
def forward(self, input, other):
qinput = self.quant_input(input, self.scale_input)
qother = self.quant_input(other, self.scale_other)
output = self.matmul_fp8(
qinput,
qother,
out_dtype=torch.bfloat16,
scale_input_inv=1.0 / self.scale_input,
scale_other_inv=1.0 / self.scale_other,
)
return output
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -84,6 +135,7 @@ def paged_attention(
 hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
 batch_size, head_num, head_size = query.shape
+fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
 output = ops.flat_pa(
 query=query.view(batch_size, 1, head_num * head_size),
 key_cache=kv_cache.key,
@@ -92,20 +144,53 @@
 block_mapping=hpu_attention_meta.block_mapping,
 block_bias=hpu_attention_meta.attn_bias,
 block_groups=hpu_attention_meta.block_groups,
+block_size=BLOCK_SIZE,
 scale=softmax_scale,
-matmul_qk_op=Matmul(),
-matmul_av_op=Matmul(),
+matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
 batch2block_matmul_op=Matmul(),
 block2batch_matmul_op=Matmul(),
-keys_fetch_func=fetch_from_cache,
-values_fetch_func=fetch_from_cache,
+keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
 )
 # Reshape the output tensor.
 return output.view(batch_size, head_num, head_size)
-__all__ = [
-"SUPPORTS_WINDOWING",
-"attention",
-"paged_attention",
-]
+def paged_attention_mla(
+query: torch.Tensor,
+kv_cache: KVCache,
+kv_head_mapping: torch.Tensor,
+softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
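The FP8 attention path above quantizes the query and the fetched cache blocks with per-tensor scales and runs the matmuls through `torch.ops.hpu.fp8_gemm_v2`, passing the inverse scales so the product comes back in bf16. A rough CPU-only reference of that scale/cast arithmetic, assuming a PyTorch build with float8 dtypes (the scale values and shapes here are illustrative, not taken from the kernel):

```
import torch


def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor quantization: scale into range, then cast to FP8.
    return (x.to(torch.float32) * scale).to(torch.float8_e4m3fn)


def fp8_matmul_reference(a, b, scale_a_inv, scale_b_inv):
    # Emulates the dequantize-then-GEMM that fp8_gemm_v2 fuses on HPU.
    a_bf16 = a.to(torch.bfloat16) * scale_a_inv
    b_bf16 = b.to(torch.bfloat16) * scale_b_inv
    return a_bf16 @ b_bf16


q = torch.randn(4, 8)
k = torch.randn(8, 4)
scale = torch.tensor(1.0)
out = fp8_matmul_reference(quant_fp8(q, scale), quant_fp8(k, scale), 1.0 / scale, 1.0 / scale)
print(out.shape)  # torch.Size([4, 4])
```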


@@ -5,7 +5,6 @@ import torch
 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops
 @dataclass
@@ -50,15 +49,17 @@
 ):
 """Construct the key-value cache for a layer."""
 ## TODO FP8 kv cache support
+if dtype is torch.float8_e5m2:
+raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
 self.kv_cache = (
 torch.zeros(
-(num_blocks, BLOCK_SIZE, num_heads, head_size),
+(num_blocks * BLOCK_SIZE, num_heads, head_size),
 dtype=dtype,
 device=device,
 ),
 torch.zeros(
-(num_blocks, BLOCK_SIZE, num_heads, head_size),
+(num_blocks * BLOCK_SIZE, num_heads, head_size),
 dtype=dtype,
 device=device,
 ),
@@ -101,24 +102,89 @@
 key_cache,
 value_cache,
 slots,
-kv_scales.key_scale_cpu,
-kv_scales.value_scale_cpu,
+kv_scales.key_scale,
+kv_scales.value_scale,
 )
class KVCompressCache(KVCache):
"""
Key-value cache for attention layers.
"""
kv_cache: torch.Tensor
def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
(num_blocks * BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache
@property
def value(self):
"""Get the value cache."""
return self.kv_cache
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
self.kv_cache.index_copy_(0, slots, key)
 def paged_reshape_and_cache(
 key: torch.Tensor,
 value: torch.Tensor,
 key_cache: torch.Tensor,
 value_cache: torch.Tensor,
 slots: torch.Tensor,
-k_scale: float = 1.0,
-v_scale: float = 1.0,
+k_scale: torch.Tensor,
+v_scale: torch.Tensor,
 ):
-block_idx = slots // BLOCK_SIZE
-block_offset = slots % BLOCK_SIZE
-cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+if key_cache.dtype == torch.float8_e4m3fn:
+key = torch.ops.hpu.cast_to_fp8_v2(
+key, k_scale, False, False, torch.float8_e4m3fn
+)[0]
+value = torch.ops.hpu.cast_to_fp8_v2(
+value, v_scale, False, False, torch.float8_e4m3fn
+)[0]
+key_cache.index_copy_(0, slots, key)
+value_cache.index_copy_(0, slots, value)
 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
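Both `KVCache` and the new `KVCompressCache` now allocate a flat `(num_blocks * BLOCK_SIZE, ...)` buffer, so storing amounts to one `index_copy_` over slot indices rather than the old block/offset arithmetic. A small sketch of that addressing change in plain PyTorch (BLOCK_SIZE and the shapes are illustrative):

```
import torch

BLOCK_SIZE = 4  # illustrative; TGI takes this from models.globals
num_blocks, num_heads, head_size = 8, 2, 16

# One row per slot instead of (num_blocks, BLOCK_SIZE, ...).
key_cache = torch.zeros(num_blocks * BLOCK_SIZE, num_heads, head_size)

slots = torch.tensor([0, 5, 9])                 # where three new tokens land
key = torch.randn(3, num_heads, head_size)

# Previously: block_idx = slots // BLOCK_SIZE, block_offset = slots % BLOCK_SIZE.
# With the flat layout the slot index addresses the row directly.
key_cache.index_copy_(0, slots, key)
print(bool(key_cache[5].abs().sum() > 0))       # True
```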


@@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
 from vllm_hpu_extension.ops import scaled_fp8_quant
 from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
-import habana_frameworks.torch.utils.experimental as htexp
-w8a8_block_fp8_matmul = None
-per_token_group_quant_fp8 = None
 quant_dtype: torch.dtype = torch.float8_e4m3fn
+FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+if is_hpu_gaudi2():
+FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
def pad_weight(weight, block_size):
"""Pads a matrix to make its dimensions multiples of block_size."""
M, N = weight.shape[-2:]
block_size_m, block_size_n = block_size
pad_M = (block_size_m - M % block_size_m) % block_size_m
pad_N = (block_size_n - N % block_size_n) % block_size_n
if pad_M == 0 and pad_N == 0:
return weight, M, N # No padding needed
padded_weight = torch.nn.functional.pad(
weight, (0, pad_N, 0, pad_M), mode="constant", value=0
)
return padded_weight, M, N # Return original dimensions for unpadding
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
return weight
if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
return weight[:original_M, :original_N]
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
assert len(block_size) == 2
block_size_m, block_size_n = block_size
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
weight, orig_M, orig_N = pad_weight(weight, block_size)
M, N = weight.shape[-2:]
assert weight_scale_m == M // block_size_m
assert weight_scale_n == N // block_size_n
return weight, orig_M, orig_N
def dynamic_quant(data, single_scale=False):
if single_scale:
scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
else:
scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
scale = scale.unsqueeze(-1)
data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
data, 1.0 / scale, False, False, torch.float8_e4m3fn
)[0]
return data_fp8, scale.float()
def dequant_block_fp8_weight_naive(
weight,
weight_scale,
block_size,
dtype=torch.bfloat16,
original_M=None,
original_N=None,
do_unpad=False,
):
if weight_scale is None:
return weight
assert len(block_size) == 2
weight_shape_len = len(weight.shape)
block_size_m, block_size_n = block_size
# mul scale
if weight_shape_len == 2:
weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = False
elif weight_shape_len == 3:
fd, weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(
fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = True
else:
raise ValueError("Only support original weight shape is either 2 or 3")
if do_unpad:
dequant_weight = unpad_weight(
dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
)
return dequant_weight
def apply_block_fp8_linear_hpu_dynamic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
x_fp8, x_scale = dynamic_quant(input_2d)
output = torch.ops.hpu.fp8_gemm_v2(
x_fp8,
False,
weight,
True,
None,
torch.bfloat16,
x_scale,
weight_scale,
None,
False,
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
@@ -42,7 +182,7 @@ def per_tensor_dequantize(
 ) -> torch.Tensor:
 device = tensor.device
 dtype = torch.bfloat16
-if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+if is_hpu_gaudi2():
 # dequant on cpu to avoid nan on gaudi2
 tensor = tensor.to("cpu")
@@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
 return UnquantizedWeight(w)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
for p in prefixes
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
for p in prefixes
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
 def get_weights_row(self, weights: "Weights", prefix: str):
 w = weights.get_sharded(f"{prefix}.weight", dim=1)
 # FP8 branch
@@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module):
 scale_upper_bound = kwargs.get("scale_upper_bound", None)
 weight_block_size = kwargs.get("weight_block_size", None)
if weight_block_size is not None:
weight, orig_M, orig_N = pad_block_fp8_weight_naive(
weight, scale, weight_block_size
)
weight, scale = dynamic_quant(
dequant_block_fp8_weight_naive(
weight,
scale,
weight_block_size,
original_M=orig_M,
original_N=orig_N,
do_unpad=True,
)
)
scale = scale.squeeze(-1)
 return cls(
 qweight=weight,
 scale=scale,
@@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module):
 def forward(self, input: torch.Tensor) -> torch.Tensor:
 if self.weight_block_size is not None:
-# https://arxiv.org/pdf/2412.19437
-# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
-# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
-# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
-# channels).
-qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
-output = w8a8_block_fp8_matmul(
-qinput,
-self.qweight,
-scale,
-self.scale,
-self.weight_block_size,
-output_dtype=input.dtype,
-)
-if self.bias is not None:
-output = output + self.bias
-return output.to(dtype=input.dtype)
+return apply_block_fp8_linear_hpu_dynamic(
+input, self.qweight, self.scale, self.input_scale, self.bias
+)
 qinput, scale = fp8_quantize(
 input,
 self.input_scale,

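The block-quantized branch of `Fp8Linear` now dequantizes the 128x128-block weights once (via `dequant_block_fp8_weight_naive`) and re-quantizes them with per-row scales, instead of calling a block-wise GEMM at every forward. The per-block scaling itself is just a broadcast multiply; a toy illustration with a 2x2 block size and bf16 stand-ins for the FP8 data (values are made up):

```
import torch

block_m, block_n = 2, 2
weight = torch.arange(16, dtype=torch.bfloat16).reshape(4, 4)  # pretend FP8 payload
weight_scale = torch.tensor([[0.5, 2.0], [1.0, 4.0]])          # one scale per 2x2 block

# View as (scale_rows, block_m, scale_cols, block_n) and broadcast the scales,
# mirroring dequant_block_fp8_weight_naive for the 2D case.
w = weight.view(2, block_m, 2, block_n).to(torch.float32)
s = weight_scale.view(2, 1, 2, 1)
dequant = (w * s).view(4, 4)
print(dequant)
```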

@@ -4,7 +4,12 @@ from typing import List, Optional, Union
 import torch
 from loguru import logger
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+Weight,
+Weights,
+WeightsLoader,
+DefaultWeightsLoader,
+)
 from .hpu import QuantLinear
@@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
 quant_method: str,
 quantize: str,
 sym: bool,
+modules_to_not_convert: List[str],
 ):
 self.bits = bits
 self.desc_act = desc_act
@@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
 self.quant_method = quant_method
 self.quantize = quantize
 self.sym = sym
+self.modules_to_not_convert = modules_to_not_convert
+def is_layer_skipped_quantization(
+self, prefix: str, modules_to_not_convert: List[str]
+):
+return any(module_name in prefix for module_name in modules_to_not_convert)
 def get_weights(self, weights: Weights, prefix: str):
 self._get_gptq_params(weights)
@@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
 log_once(logger.warning, "Disabling exllama because desc_act=True")
 use_exllama = False
+if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+return DefaultWeightsLoader.get_weights(weights, prefix)
 try:
 qweight = weights.get_tensor(f"{prefix}.qweight")
 except RuntimeError:
@@ -145,6 +160,10 @@
 prefix: str,
 block_sizes: Union[int, List[int]],
 ):
+if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+return DefaultWeightsLoader.get_weights_col_packed(
+weights, prefix, block_sizes
+)
 try:
 qweight = weights.get_packed_sharded(
 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -196,6 +215,8 @@
 )
 def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
 try:
 qweight = torch.cat(
 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -263,6 +284,9 @@
 if self.bits != 4:
 use_exllama = False
+if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+return DefaultWeightsLoader.get_weights_row(weights, prefix)
 if self.desc_act:
 log_once(logger.warning, "Disabling exllama because desc_act=True")
 use_exllama = False


@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
 return cls(weight, eps)
 def forward(self, hidden_states, residual=None):
-from vllm_hpu_extension.kernels import rms_norm
-orig_shape = hidden_states.shape
 if residual is not None:
-residual += hidden_states.view(residual.shape)
-else:
-residual = hidden_states
-# Note: HPUFusedRMSNorm requires 3D tensors as inputs
-if len(orig_shape) == 2:
-residual = residual.unsqueeze(0)
-x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
-return x.view(orig_shape), residual.view(orig_shape)
+hidden_states += residual
+residual = hidden_states
+hidden_states = hidden_states.to(torch.float32)
+variance = hidden_states.pow(2).mean(-1, keepdim=True)
+hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+return self.weight * hidden_states.to(self.weight.dtype), residual
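The rewritten `FastRMSNorm.forward` computes the normalization directly in torch (upcast to fp32, scale by the reciprocal RMS, re-apply the weight) instead of calling the HPU fused kernel. A quick standalone check of the same math, with arbitrary shapes and epsilon:

```
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same steps as the new forward(): upcast, divide by RMS, rescale by the weight.
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)
    return weight * x32.to(weight.dtype)


x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
print(rms_norm_reference(x, w, 1e-6).shape)  # torch.Size([2, 8])
```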


@@ -2,6 +2,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
+import os
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.fp8 import (
@@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import (
 fp8_quantize,
 quant_dtype,
 normalize_e4m3fn_to_native_float8,
+dynamic_quant,
+dequant_block_fp8_weight_naive,
 )
-try:
-from .unquantized import fused_moe
-except Exception:
-fused_moe = None
+from text_generation_server.layers.moe.fused_moe import select_experts
+import habana_frameworks.torch as htorch
 class FP8SparseMoELayer(nn.Module):
@@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
 self.weight_block_size = weights.weights_loader.weight_block_size
 self.scoring_func = scoring_func
 self.e_score_correction_bias = e_score_correction_bias
self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank()
self.ep_rank = self.rank
self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
if self.use_ep:
n_experts = (n_experts + self.world_size - 1) // self.world_size
self.ep_offset = self.ep_rank * n_experts
else:
self.ep_offset = 0
 (
 self.gate_up_proj,
@@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
 gate_proj_name=gate_proj_name,
 up_proj_name=up_proj_name,
 weights=weights,
+use_ep=self.use_ep,
+ep_offset=self.ep_offset,
 )
 self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
@@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module):
 n_experts=n_experts,
 name=down_proj_name,
 weights=weights,
+use_ep=self.use_ep,
+ep_offset=self.ep_offset,
 )
 )
if self.weight_block_size is not None:
self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.weight_block_size,
)
)
self.down_proj, self.down_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.down_proj, self.down_proj_weight_scale, self.weight_block_size
)
)
self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
self.gate_up_proj_weight_scale.squeeze(-1),
self.down_proj_weight_scale.squeeze(-1),
)
 def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-return fused_moe(
-x,
-w1=self.gate_up_proj,
-w2=self.down_proj,
-gating_output=gating_output,
-topk=self.topk,
-renormalize=self.renormalize,
-inplace=True,
+topk_weights, topk_ids = select_experts(
+hidden_states=x,
+router_logits=gating_output,
 use_grouped_topk=self.n_expert_group is not None,
-num_expert_group=self.n_expert_group,
+top_k=self.topk,
+renormalize=self.renormalize,
 topk_group=self.topk_group,
+num_expert_group=self.n_expert_group,
 scoring_func=self.scoring_func,
 e_score_correction_bias=self.e_score_correction_bias,
-use_fp8_w8a8=True,
-w1_scale=self.gate_up_proj_weight_scale,
-w2_scale=self.down_proj_weight_scale,
-a1_scale=self.gate_up_proj_input_scale,
-a2_scale=self.down_proj_input_scale,
 )
total_num_experts = gating_output.size(-1)
x_fp8, x_scale = dynamic_quant(x, single_scale=True)
if self.use_ep:
moe_n_slice = 1
n_expert_slice = (
total_num_experts + self.world_size - 1
) // self.world_size
else:
moe_n_slice = 1
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
for i in range(moe_n_slice):
min_expert = i * n_expert_slice
max_expert = min((i + 1) * n_expert_slice, total_num_experts)
w13_list_slice = [
self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
]
w2_list_slice = [
self.down_proj[j, ...] for j in range(min_expert, max_expert)
]
w13_weight_scale = [
self.gate_up_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
w2_weight_scale = [
self.down_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
current_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=x_fp8,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
w3=w2_list_slice,
d_scale_hidden_states=x_scale,
d_scale_w12=w13_weight_scale,
d_scale_w3=w2_weight_scale,
permuted_weights=True,
activation="silu",
experts_min=min_expert + self.ep_offset,
experts_max=max_expert + self.ep_offset - 1,
)
htorch.core.mark_step()
if i == 0:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return final_hidden_states
 def _load_expert_weights(
@@ -98,13 +170,14 @@ def _load_expert_weights(
 n_experts: int,
 name: str,
 weights: Weights,
+ep_offset: int = 0,
 ) -> torch.Tensor:
 all_weight = None
 all_weight_scales = None
 max_input_scale = None
 for i in range(n_experts):
-weight = get_weight_fn(prefix, i, name, weights)
+weight = get_weight_fn(prefix, i + ep_offset, name, weights)
 assert isinstance(weight, Fp8Weight)
@@ -147,14 +220,26 @@ def _load_expert_multi_weights_col(
 gate_proj_name: str,
 up_proj_name: str,
 weights: Weights,
+use_ep: bool = False,
+ep_offset: int = 0,
 ) -> torch.Tensor:
-def get_weight_fn(prefix, i, name, weights):
+def get_weight_fn_sharded(prefix, i, name, weights):
 return weights.get_multi_weights_col(
 [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
 )
+def get_weight_fn(prefix, i, name, weights):
+return weights.get_multi_weights(
+[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+)
 return _load_expert_weights(
-get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+get_weight_fn if use_ep else get_weight_fn_sharded,
+prefix=prefix,
+n_experts=n_experts,
+name=None,
+weights=weights,
+ep_offset=ep_offset if use_ep else 0,
 )
@@ -164,10 +249,20 @@ def _load_expert_weights_row(
 n_experts: int,
 name: str,
 weights: Weights,
+use_ep: bool = False,
+ep_offset: int = 0,
 ) -> torch.Tensor:
-def get_weight_fn(prefix, i, name, weights):
+def get_weight_fn_sharded(prefix, i, name, weights):
 return weights.get_weights_row(f"{prefix}.{i}.{name}")
+def get_weight_fn(prefix, i, name, weights):
+return weights.get_weights(f"{prefix}.{i}.{name}")
 return _load_expert_weights(
-get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+get_weight_fn if use_ep else get_weight_fn_sharded,
+prefix=prefix,
+n_experts=n_experts,
+name=name,
+weights=weights,
+ep_offset=ep_offset if use_ep else 0,
 )
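In the expert-parallel path added to `FP8SparseMoELayer` above, each rank loads a contiguous slice of experts: the expert count is ceil-divided by the world size and the slice starts at `rank * experts_per_rank`, which is also the `ep_offset` threaded into the weight loaders. A small sketch of that partitioning arithmetic (the clamp to `n_experts` is added here for illustration; rank and world-size values are made up):

```
def expert_partition(n_experts: int, world_size: int, rank: int):
    # Ceil-divide so every rank gets at most experts_per_rank experts.
    experts_per_rank = (n_experts + world_size - 1) // world_size
    ep_offset = rank * experts_per_rank
    experts_min = ep_offset
    experts_max = min(ep_offset + experts_per_rank, n_experts) - 1
    return experts_min, experts_max


# 64 experts over 8 ranks: rank 3 owns experts 24..31.
print(expert_partition(64, 8, 3))  # (24, 31)
```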


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from typing import Tuple, Optional
 import torch
@@ -25,12 +25,36 @@ def grouped_topk(
 renormalize: bool,
 num_expert_group: int = 0,
 topk_group: int = 0,
+scoring_func: str = "softmax",
+e_score_correction_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-scores = torch.softmax(gating_output, dim=-1)
+assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+gating_output = gating_output.float()
+if e_score_correction_bias is not None:
+e_score_correction_bias = e_score_correction_bias.float()
+if scoring_func == "softmax":
+scores = torch.softmax(gating_output, dim=-1)
+elif scoring_func == "sigmoid":
+scores = gating_output.sigmoid()
+else:
+raise ValueError(f"Unsupported scoring function: {scoring_func}")
 num_token = scores.shape[0]
-group_scores = (
-scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-) # [n, n_group]
+if e_score_correction_bias is not None:
+# Store original scores before applying correction bias. We use biased
+# scores for expert selection but original scores for routing weights
+original_scores = scores
+scores = scores + e_score_correction_bias.unsqueeze(0)
+group_scores = (
+scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+)
+else:
+group_scores = (
+scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+) # [n, n_group]
 group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
 1
 ] # [n, top_k_group]
@@ -41,13 +65,19 @@ def grouped_topk(
 .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
 .reshape(num_token, -1)
 ) # [n, e]
-tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
-topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
+if e_score_correction_bias is not None:
+topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+# Use original unbiased scores for the routing weights
+topk_weights = original_scores.gather(1, topk_ids)
+else:
+topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 if renormalize:
 topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-return topk_weights, topk_ids
+return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 def fused_topk(
@@ -63,3 +93,39 @@ def fused_topk(
 if renormalize:
 topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
 return topk_weights, topk_ids
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids
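`select_experts` falls back to plain softmax top-k routing (`fused_topk`) whenever grouped top-k is not requested. A self-contained sketch of that non-grouped branch (not the library function itself), with made-up token and expert counts:

```
import torch


def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts, pick the top-k, optionally renormalize their weights.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 16)               # 4 tokens routed over 16 experts
w, ids = fused_topk_reference(logits, topk=2, renormalize=True)
print(w.shape, ids.shape)                 # torch.Size([4, 2]) torch.Size([4, 2])
```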


@@ -4,7 +4,9 @@ import torch
 import torch.nn as nn
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
-from vllm_hpu_extension.ops import DynamicFusedMOE
+from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
+import habana_frameworks.torch as htorch
+import torch.nn.functional as F
 class UnquantizedSparseMoELayer(nn.Module):
@@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
 weights=weights,
 )
-self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
 for i in range(n_experts):
-self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
-self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
 def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-return self.hpu_fused_moe(x, gating_output, self.topk)
+htorch.core.mark_step()
+routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(
routing_weights, self.topk, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
final_hidden_states = self.MoeOp(
hidden_states=x,
expert_routing_table=selected_experts,
router_weights=routing_weights,
permuted_weights=True,
activation="silu",
)
return final_hidden_states.view(-1, x.shape[1])
 def _load_expert_multi_weights_col(


@@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
 mscale_all_dim: float,
 ):
 inv_freq = _create_inv_freq(dim, base, device)
-super().__init__(
-inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
-)
 self.dim = dim
 self.max_position_embeddings = max_position_embeddings
 self.base = base
@@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
 / get_mscale(self.scaling_factor, mscale_all_dim)
 * self.attn_factor
 ) # Get n-d magnitude scaling corrected for interpolation
+super().__init__(inv_freq, scaling_factor, max_position_embeddings)
 def _update_cos_sin_cache(self, dtype, device, seqlen):
 # Reset the tables if the sequence length has changed,


@@ -360,6 +360,7 @@ def get_model(
 quantize: Optional[str],
 speculate: Optional[int],
 dtype: Optional[torch.dtype],
+kv_cache_dtype: Optional[str],
 trust_remote_code: bool,
 max_input_tokens: int,
 ) -> Model:
@@ -485,7 +486,12 @@ def get_model(
 model_type = config_dict["model_type"]
-kv_cache_dtype = dtype
+if kv_cache_dtype == "fp8_e4m3fn":
+kv_cache_dtype = torch.float8_e4m3fn
+elif kv_cache_dtype == "fp8_e5m2":
+kv_cache_dtype = torch.float8_e5m2
+else:
+kv_cache_dtype = dtype
 if FLASH_ATTENTION:
 if model_type == DEEPSEEK_V2:
@@ -976,6 +982,7 @@ def get_model_with_lora_adapters(
 quantize: Optional[str],
 speculate: Optional[int],
 dtype: Optional[torch.dtype],
+kv_cache_dtype: Optional[str],
 trust_remote_code: bool,
 max_input_tokens: int,
 adapter_to_index: Dict[str, int],
@@ -989,6 +996,7 @@
 quantize,
 speculate,
 dtype,
+kv_cache_dtype,
 trust_remote_code,
 max_input_tokens,
 )


@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
 apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 class CohereRotary(PositionRotaryEmbedding):
 def forward(
@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
 cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 residual = None
+lazy_mode = htorch.utils.internal.is_lazy()
+if lazy_mode:
+htorch.core.mark_step()
 for i, layer in enumerate(self.layers):
 hidden_states, residual = layer(
 hidden_states,
@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
 seqlen,
 hpu_attention_meta,
 )
+if lazy_mode:
+htorch.core.mark_step()
 hidden_states, _ = self.norm(hidden_states, residual)


@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
 from vllm_hpu_extension.ops import DynamicFusedMOE
+import habana_frameworks.torch as htorch
 class DbrxAttentionConfig(PretrainedConfig):
@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
 # Get rotary cos and sin for this forward
 # Avoid to index in each layer
 cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
 residual = None
+lazy_mode = htorch.utils.internal.is_lazy()
+if lazy_mode:
+htorch.core.mark_step()
 for i, layer in enumerate(self.layers):
 hidden_states, residual = layer(
 hidden_states,
@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
 seqlen,
 hpu_attention_meta,
 )
+if lazy_mode:
+htorch.core.mark_step()
 hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
class DeepseekV2Config(PretrainedConfig): class DeepseekV2Config(PretrainedConfig):
@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -28,11 +28,12 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
get_linear, get_linear,
Fp8Linear,
) )
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
Seqlen, Seqlen,
attention, attention,
paged_attention, paged_attention_mla,
HPUPagedAttentionMetadata, HPUPagedAttentionMetadata,
) )
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
if isinstance(layer, Fp8Linear):
eye = torch.eye(
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
)
dequant_weights = layer(eye)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
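Editorial note: get_and_maybe_dequant_weights recovers a dense weight matrix from an Fp8Linear by pushing an identity matrix through the layer: layer(eye) yields the transposed weight, which is then transposed back to the usual (out, in) layout. The same trick demonstrated on a plain bias-free nn.Linear as a stand-in for the quantized layer (purely illustrative):

import torch

def dense_weight_via_identity(layer: torch.nn.Linear) -> torch.Tensor:
    eye = torch.eye(layer.in_features, dtype=layer.weight.dtype, device=layer.weight.device)
    # eye @ W.T == W.T, so one forward pass materializes the (dequantized) weight
    return layer(eye).T

lin = torch.nn.Linear(4, 3, bias=False)
assert torch.allclose(dense_weight_via_identity(lin), lin.weight)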
class DeepseekV3Config(PretrainedConfig): class DeepseekV3Config(PretrainedConfig):
@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups) ).repeat_interleave(self.num_groups)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
): ):
if self.q_lora_rank is None: if self.q_lora_rank is None:
query = self.q_proj(hidden_states) hidden_states_or_q_c = hidden_states
else: else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split( compressed_kv, key_pe = torch.split(
@ -276,13 +323,18 @@ class DeepseekV3Attention(torch.nn.Module):
) )
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split( # Prefill
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 if cu_seqlen_prefill is not None:
) q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
batch_size, heads, head_dim = query_pe.shape batch_size, heads, head_dim = query_pe.shape
query_pe = ( query_pe = (
@ -297,33 +349,47 @@ class DeepseekV3Attention(torch.nn.Module):
.reshape(batch_size, heads, head_dim) .reshape(batch_size, heads, head_dim)
) )
self.rotary_emb(query_pe, key_pe, cos, sin) self.rotary_emb(query_pe, key_pe, cos, sin)
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
query[..., self.qk_nope_head_dim :] = query_pe latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store( kv_cache.store(
key=key, key=latent_vec_k,
value=value, value=None,
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
# Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query=query, query=query,
@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen=seqlen, seqlen=seqlen,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
) )
# Decode attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else: else:
attn_output = paged_attention( # Decode
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query, query,
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
) )
attn_output = self._v_up_proj_and_o_proj(attn_output)
# Remove padding. return attn_output
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
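Editorial note: the rewritten attention above is the MLA weight-absorption path. Prefill still up-projects to full K/V and runs flash attention, but the cache now stores only the concatenated latent (kv_c_normed plus the rotary key part), and decode folds W_UK into the query (_q_proj_and_k_up_proj) and W_UV into the output projection (_v_up_proj_and_o_proj) around paged_attention_mla. A toy shape check of the two batched matmuls, with made-up sizes:

import torch

N, B, P, L, V = 8, 4, 64, 128, 64        # heads, batch, nope dim, kv_lora_rank, value dim (illustrative)
W_UK_T = torch.randn(N, P, L)            # per-head K up-projection, stored as (N, P, L)
W_UV = torch.randn(N, L, V)              # per-head V up-projection, stored as (N, L, V)

q_nope = torch.randn(B, N, P).transpose(0, 1)    # (N, B, P)
ql_nope = torch.bmm(q_nope, W_UK_T)              # (N, B, L): query moved into the latent space
attn_latent = torch.randn(N, B, L)               # attention output over latents, reshaped to (N, B, L)
out = torch.bmm(attn_latent, W_UV)               # (N, B, V)
out = out.transpose(0, 1).reshape(B, N * V)      # flatten heads, ready for o_proj
print(ql_nope.shape, out.shape)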
class DeepseekV3MLP(nn.Module): class DeepseekV3MLP(nn.Module):
@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Gemma2Config(PretrainedConfig): class Gemma2Config(PretrainedConfig):
@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GemmaConfig(PretrainedConfig): class GemmaConfig(PretrainedConfig):
@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -38,6 +38,7 @@ from text_generation_server.layers import (
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention.kv_cache import get_kv_scales
import habana_frameworks.torch as htorch
def load_qkv(config, prefix: str, weights, head_size, num_heads): def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)

View File

@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, RotaryPosEmbeddingMode,
apply_rotary_pos_emb, apply_rotary_pos_emb,
) )
import habana_frameworks.torch as htorch
def load_attention(config, prefix: str, weights): def load_attention(config, prefix: str, weights):
@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -26,7 +26,7 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
KVCache, KVCache,
get_kv_scales, get_kv_scales,
@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
cross_attention_states, cross_attention_states,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
import habana_frameworks.torch as htorch
class MistralConfig(PretrainedConfig): class MistralConfig(PretrainedConfig):
@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states

View File

@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class MixtralConfig(PretrainedConfig): class MixtralConfig(PretrainedConfig):
@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GPTNeoXConfig(TransformersGPTNeoXConfig): class GPTNeoXConfig(TransformersGPTNeoXConfig):
@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)

View File

@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import ( from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
import habana_frameworks.torch as htorch
class PhiConfig(PretrainedConfig): class PhiConfig(PretrainedConfig):
@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -18,7 +18,6 @@
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)

View File

@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
import habana_frameworks.torch as htorch
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
) )
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states = layer( hidden_states = layer(
hidden_states, hidden_states,
@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states)

View File

@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
HPUPagedAttentionMetadata, HPUPagedAttentionMetadata,
) )
import habana_frameworks.torch as htorch
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.h): for i, layer in enumerate(self.h):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
import habana_frameworks.torch as htorch
def load_multi_mqa( def load_multi_mqa(
@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Starcoder2Config(PretrainedConfig): class Starcoder2Config(PretrainedConfig):
@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
) )
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
KVCache, KVCache,
KVCompressCache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata, HPUPagedAttentionMetadata,
trim_attn_metadata, trim_attn_metadata,
@ -68,11 +69,14 @@ from text_generation_server.utils.import_utils import (
synchronize, synchronize,
get_free_memory, get_free_memory,
) )
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import itertools import itertools
from vllm_hpu_extension.bucketing import HPUBucketingContext from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -425,7 +429,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor[i, : len(input_ids)] = input_ids all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device # Create tensors on device
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@ -628,21 +634,25 @@ class FlashCausalLMBatch(Batch):
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices] cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
if self.adapter_meta is not None:
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_indices = self.adapter_meta.adapter_indices[indices]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_segments, adapter_segment_indices = find_segments(
adapter_meta = AdapterBatchMetadata( adapter_indices
adapter_indices=adapter_indices, )
adapter_set=adapter_set, adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_segments=adapter_segments, adapter_meta = AdapterBatchMetadata(
segment_indices=adapter_segment_indices, adapter_indices=adapter_indices,
) adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
else:
adapter_meta = None
htorch.core.mark_step() htorch.core.mark_step()
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -704,6 +714,7 @@ class FlashCausalLMBatch(Batch):
max_length = 0 max_length = 0
max_input_length = 0 max_input_length = 0
max_current_length = 0 max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
@ -757,14 +768,15 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size total_batch_size
) )
total_indices_size = sum( if ADAPTER_TO_INDEX:
b.adapter_meta.adapter_indices.shape[0] for b in batches total_indices_size = sum(
) b.adapter_meta.adapter_indices.shape[0] for b in batches
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( )
total_indices_size adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
) total_indices_size
adapter_segment_builder = SegmentConcatBuilder() )
adapter_set = set() adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size total_batch_size
@ -815,9 +827,7 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize end_index = cumulative_batch_size + valid_bsize
index = torch.tensor( index = torch.tensor(list(range(start_index, end_index)), device="cpu")
list(range(start_index, end_index)), device=batch.input_ids.device
)
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[ all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1] start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@ -841,7 +851,9 @@ class FlashCausalLMBatch(Batch):
) )
if not prefilling: if not prefilling:
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_( slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots 0, index, batch.slot_indices + cumulative_slots
@ -852,20 +864,21 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor.index_copy_( cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize] 0, index, batch.cache_lengths_tensor[:valid_bsize]
) )
adapter_start_index = cumulative_adapter_indices_size if ADAPTER_TO_INDEX:
adapter_end_index = ( adapter_start_index = cumulative_adapter_indices_size
cumulative_adapter_indices_size adapter_end_index = (
+ batch.adapter_meta.adapter_indices.shape[0] cumulative_adapter_indices_size
) + batch.adapter_meta.adapter_indices.shape[0]
adapter_indices[adapter_start_index:adapter_end_index] = ( )
batch.adapter_meta.adapter_indices adapter_indices[adapter_start_index:adapter_end_index] = (
) batch.adapter_meta.adapter_indices
cumulative_adapter_indices_size = adapter_end_index )
adapter_set.update(batch.adapter_meta.adapter_set) cumulative_adapter_indices_size = adapter_end_index
adapter_segment_builder.concat( adapter_set.update(batch.adapter_meta.adapter_set)
batch.adapter_meta.adapter_segments, adapter_segment_builder.concat(
batch.adapter_meta.segment_indices, batch.adapter_meta.adapter_segments,
) batch.adapter_meta.segment_indices,
)
else: else:
if isinstance(batch.input_ids, torch.Tensor): if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist() batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@ -908,7 +921,7 @@ class FlashCausalLMBatch(Batch):
else: else:
speculative_ids = None speculative_ids = None
if adapter_segment_builder is not None: if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata( adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
@ -955,7 +968,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids, speculative_ids=speculative_ids,
adapter_meta=adapter_meta, adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None, hpu_attn_meta=None,
next_token_logits=None, next_token_logits=None,
speculative_logits=None, speculative_logits=None,
@ -1031,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
# need extra pad to match warmup seq # need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self) extra_pad_bs = max_padded_bs - len(self)
device = self.all_input_ids_tensor.device
if isinstance(self.input_ids, list) and len(self) > 1: if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = [] input_ids_padded_length = []
input_ids = [] input_ids = []
@ -1041,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
input_ids.append(input_id) input_ids.append(input_id)
input_ids_padded_length.append(padded) input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list): elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0] input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad) input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else: else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.extend([extra_pad] * len(self)) input_ids_padded_length.extend([extra_pad] * len(self))
@ -1239,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
self.slot_indices = slot_indices self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs: if all_prefill_logprobs:
@ -1295,21 +1311,24 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states, fsm_grammar_states,
) )
if adapter_set: if ADAPTER_TO_INDEX:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) if adapter_set:
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
else: adapter_segments, adapter_segment_indices = find_segments(
adapter_indices = torch.zeros_like(self.input_ids) adapter_indices
adapter_segments = [0, len(adapter_indices)] )
adapter_segment_indices = [len(adapter_indices) - 1] else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
self.adapter_meta = AdapterBatchMetadata( self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
adapter_set=adapter_set, adapter_set=adapter_set,
adapter_segments=adapter_segments, adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices, segment_indices=adapter_segment_indices,
) )
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
@ -1352,6 +1371,8 @@ class FlashCausalLM(Model):
): ):
self.quantize = quantize self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if world_size > 1:
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
device = torch.device("hpu") device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
@ -1439,15 +1460,18 @@ class FlashCausalLM(Model):
self.kv_cache = [] self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None self.bucketing_ctx = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config) environment.set_model_config(self.config)
self.use_contiguous_pa = ( self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
) )
self.limit_hpu_graphs = ( self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true" os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
) )
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
super().__init__( super().__init__(
model_id=model_id, model_id=model_id,
model=model, model=model,
@ -1479,16 +1503,27 @@ class FlashCausalLM(Model):
): ):
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
self.kv_cache = [ if self.config.model_type == "deepseek_v3":
KVCache( self.kv_cache = [
num_blocks=num_blocks, KVCompressCache(
num_heads=num_heads, num_blocks=num_blocks,
head_size=head_size, head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
for _ in range(num_layers) for _ in range(num_layers)
] ]
else:
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
num_heads=num_heads,
head_size=head_size,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
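Editorial note: for deepseek_v3 the cache switches to KVCompressCache, which keeps a single latent vector of width kv_lora_rank + qk_rope_head_dim per token instead of per-head K and V. A rough per-token width comparison (all numbers are illustrative assumptions, not read from a config):

kv_lora_rank, qk_rope_head_dim = 512, 64              # latent plus rotary part kept in cache
num_heads, qk_head_dim, v_head_dim = 128, 192, 128    # what a full K/V cache would materialize

latent_width = kv_lora_rank + qk_rope_head_dim        # values per token per layer
full_width = num_heads * (qk_head_dim + v_head_dim)   # values per token per layer
print(latent_width, full_width, round(full_width / latent_width, 1))  # roughly 71x smaller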
def warmup( def warmup(
self, self,
@ -1496,16 +1531,40 @@ class FlashCausalLM(Model):
max_input_tokens: Optional[int], max_input_tokens: Optional[int],
max_total_tokens: Optional[int], max_total_tokens: Optional[int],
): ):
if os.environ.get("MAX_BATCH_SIZE") is None:
raise RuntimeError(
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
"using `--max-batch-size xxx`"
)
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
self.graphed_buckets = set()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory # Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size if self.config.model_type == "deepseek_v3":
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size cache_block_size = BLOCK_SIZE * (
self.config.kv_lora_rank + self.config.qk_rope_head_dim
)
else:
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
if htorch.utils.internal.is_lazy()
else 0
)
mem_used_from_graph = int(
(free_memory - self.mem_reserved) * graph_reserved_mem
)
log_master(
logger.info,
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
)
try: try:
self.init_kv_cache( self.init_kv_cache(
batch.num_blocks, batch.num_blocks,
@ -1520,15 +1579,6 @@ class FlashCausalLM(Model):
num_tokens = batch.to_pb().current_tokens num_tokens = batch.to_pb().current_tokens
synchronize(self.device) synchronize(self.device)
free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
)
real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
log_master(
logger.debug,
f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
)
_, _batch, _ = self.generate_token([batch]) _, _batch, _ = self.generate_token([batch])
except Exception: except Exception:
raise RuntimeError( raise RuntimeError(
@ -1537,8 +1587,9 @@ class FlashCausalLM(Model):
) )
synchronize(self.device) synchronize(self.device)
free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM) free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
kv_memory = free_memory
kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room
int(kv_memory // total_cache_size) int(kv_memory // total_cache_size)
@ -1555,7 +1606,6 @@ class FlashCausalLM(Model):
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
self.init_kv_cache( self.init_kv_cache(
num_blocks, num_blocks,
self.num_layers, self.num_layers,
@ -1564,56 +1614,177 @@ class FlashCausalLM(Model):
self.kv_cache_dtype, self.kv_cache_dtype,
self.device, self.device,
) )
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128)) max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None: HPUBucketingContext = get_bucketing_context()
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens) # need to warmup one more step since block is allocated from 1
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None: block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
max_total_blocks = ( max_total_tokens_aligned = math.ceil(
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1 max_total_tokens / BLOCK_SIZE
) ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks) model_max_length = self.tokenizer.model_max_length
max_position_embeddings = getattr(
self.config, "max_position_embeddings", model_max_length
)
self.bucketing_ctx = HPUBucketingContext( self.bucketing_ctx = HPUBucketingContext(
max_num_seqs, max_num_seqs,
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO max_num_seqs, # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE, BLOCK_SIZE,
num_blocks * BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned,
False, False,
min(model_max_length, max_position_embeddings),
max_input_tokens,
max_total_tokens_aligned,
) )
self.bucketing_ctx.num_hpu_blocks = num_blocks max_blocks = max(
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
logger.info("skip warmup hpu graph, not recommmended") )
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
synchronize(self.device)
if self.skip_warmup:
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
log_master(
logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
)
del _batch, batch del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch) self.warmup_hpu_graph(batch)
del _batch, batch del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
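Editorial note: the block-count estimate in warmup reduces to a few multiplications: bytes per cache block (times two for K and V), times layers, divided into whatever is left after the runtime reserve and the HPU-graph reserve. A back-of-envelope version of the non-MLA branch with illustrative numbers (BLOCK_SIZE, fractions and model sizes are assumptions made for the arithmetic only):

BLOCK_SIZE = 128                          # tokens per cache block (assumed)
num_kv_heads, head_size, num_layers = 8, 128, 32
dtype_size = 2                            # bf16 bytes per element

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size * 2       # K and V
total_cache_size = num_layers * cache_block_size * dtype_size      # bytes per block across layers

free_memory = 64 * 1024**3                # pretend 64 GiB free after loading weights
MEMORY_FRACTION = 0.8                     # assumed
mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
mem_used_from_graph = int((free_memory - mem_reserved) * 0.1)      # TGI_GRAPH_RESERVED_MEM=0.1

kv_memory = free_memory - mem_reserved - mem_used_from_graph
num_blocks = kv_memory // total_cache_size
print(num_blocks, "blocks ->", num_blocks * BLOCK_SIZE, "cacheable tokens")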
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
phase = "Prompt" if prefilling else "Decode"
dim = "seq_len" if prefilling else "num_blocks"
graphed_bucket = (batch_size, seq_len, prefilling)
bypass = graphed_bucket not in self.graphed_buckets
msg = (
f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
f"{dim}:{seq_len} "
f"bypass:{bypass} "
f"free_mem:{free_mem}"
)
log_master(logger.info, msg)
def use_graphs(self, prefill, seq_len, batch_size):
if self.limit_hpu_graph and prefill:
return False
if self.skip_warmup:
return True
return (batch_size, seq_len, prefill) in self.graphed_buckets
def align_workers(self, value, op):
if self.world_size <= 1:
return value
value_t = torch.tensor(value, device="cpu")
torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
return value_t.item()
def warmup_hpu_graph(self, batch): def warmup_hpu_graph(self, batch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) def ordering_function_min_tokens(b):
): return (b[0] * b[1], b[1], b[0])
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): buckets = list(
self.warmup_prefill(seq_len, batch_size, batch) sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
log_master( # Graph memory usage is proportional to seq dimension in a batch
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
) )
for index in range(warmup_times): if graphed_bucket in self.graphed_buckets:
self.warmup_decode(batch_size, block_num, batch) available_mem -= used_mem
synchronize(self.device) total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
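Editorial note: the loops above decide which (batch_size, seq_len) prompt buckets get captured as HPU graphs (the decode pass repeats the same bookkeeping with batch_size as the proportional measure). A bucket's graph memory is estimated as proportional to its batch*seq size relative to what has been measured so far, and capture stops once the estimate no longer fits the remaining budget. A stripped-down paraphrase of that estimator (the function name and the measure callback are hypothetical; the real code profiles the warmup runs with HabanaMemoryProfiler and aligns the measurement across workers):

def plan_graph_buckets(buckets, available_mem, measure, max_seq_len_to_capture=8192):
    # buckets: iterable of (batch_size, seq_len); measure(bs, sl) -> bytes consumed by warmup.
    graphed, total_mem, total_batch_seq = set(), 0, 0.001
    for batch_size, seq_len in buckets:
        batch_seq = batch_size * seq_len
        mem_estimate = batch_seq / total_batch_seq * total_mem
        if mem_estimate < available_mem and batch_seq <= max_seq_len_to_capture:
            graphed.add((batch_size, seq_len))
        used = measure(batch_size, seq_len)            # profiled warmup run for this bucket
        if (batch_size, seq_len) in graphed:
            available_mem -= used                      # only captured graphs consume the budget
        total_mem += used
        total_batch_seq += batch_seq
    return graphed

# tiny usage example with a fake linear cost model (1 byte per token, purely illustrative)
print(plan_graph_buckets([(1, 512), (2, 512), (4, 1024)], 4096, lambda b, s: b * s))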
def warmup_prefill( def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
@ -1644,7 +1815,9 @@ class FlashCausalLM(Model):
lm_head_indices = input_lengths - 1 lm_head_indices = input_lengths - 1
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
@ -1697,7 +1870,9 @@ class FlashCausalLM(Model):
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=_async_h2d_tensor_copy(input_ids),
@ -1780,11 +1955,11 @@ class FlashCausalLM(Model):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
seqlen = Seqlen( seqlen = Seqlen(
@ -1793,12 +1968,18 @@ class FlashCausalLM(Model):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = ( batch_size = input_lengths.shape[0]
batch.prefilling if self.limit_hpu_graphs else False prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,
@ -1837,9 +2018,7 @@ class FlashCausalLM(Model):
accepted_ids, accepted_ids,
speculative_ids, speculative_ids,
) = batch.next_token_chooser( ) = batch.next_token_chooser(
_async_h2d_tensor_copy( batch.all_input_ids_tensor[:, : batch.max_current_length],
batch.all_input_ids_tensor[:, : batch.max_current_length]
),
batch.next_token_logits, batch.next_token_logits,
speculate, speculate,
batch.speculative_ids, batch.speculative_ids,
@ -1853,7 +2032,6 @@ class FlashCausalLM(Model):
accepted_ids, accepted_ids,
) )
if batch.valid_indices is not None: if batch.valid_indices is not None:
next_input_ids = next_input_ids.cpu()
next_token_logprobs = next_token_logprobs.cpu() next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu() accepted_ids = accepted_ids.cpu()
batch.all_input_ids_tensor = batch.all_input_ids_tensor[ batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@ -1895,16 +2073,16 @@ class FlashCausalLM(Model):
batch.position_ids = batch.position_ids[indices] batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]] batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
batch.adapter_meta.adapter_indices = ( if batch.adapter_meta is not None:
batch.adapter_meta.adapter_indices[indices] batch.adapter_meta.adapter_indices = (
) batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch # For each member of the batch
# Cumulative length # Cumulative length
accepted_ids = accepted_ids.cpu()
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
next_input_ids = next_input_ids.cpu()
if batch.speculative_logits is not None: if batch.speculative_logits is not None:
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
for i in range(len(batch)): for i in range(len(batch)):
batch.all_input_ids_tensor[ batch.all_input_ids_tensor[
i, i,
@ -1913,9 +2091,23 @@ class FlashCausalLM(Model):
+ batch.input_lengths[i] + batch.input_lengths[i]
+ accepted_ids[i], + accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
accepted_ids = accepted_ids.cpu()
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += accepted_ids[: len(batch)]
else: else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor) index = index.to(batch.all_input_ids_tensor.device)
batch_idx = torch.arange( batch_idx = torch.arange(
0, 0,
batch.all_input_ids_tensor.shape[0], batch.all_input_ids_tensor.shape[0],
@ -1925,21 +2117,18 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor.index_put_( batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids (batch_idx, index.long()), next_input_ids
) )
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.input_ids = next_input_ids
batch.position_ids += 1
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += 1
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids[: len(batch)]
# Does a HPU <-> CPU sync internally # Does a HPU <-> CPU sync internally
if prefill: if prefill and batch.adapter_meta is not None:
# adjust segment lengths to account for all request lengths being 1 during decoding # adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments( adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices batch.adapter_meta.adapter_indices
@ -2030,30 +2219,33 @@ class FlashCausalLM(Model):
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present) # Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None: if adapter_meta is not None:
B, speculative_length = batch.speculative_ids.shape if batch.speculative_ids is not None:
new_length = speculative_length + 1 B, speculative_length = batch.speculative_ids.shape
adapter_indices = ( new_length = speculative_length + 1
adapter_meta.adapter_indices.unsqueeze(-1) adapter_indices = (
.expand(B, new_length) adapter_meta.adapter_indices.unsqueeze(-1)
.reshape(-1) .expand(B, new_length)
) .reshape(-1)
adapter_segments = adapter_meta.adapter_segments * new_length )
adapter_meta = AdapterBatchMetadata( adapter_segments = adapter_meta.adapter_segments * new_length
adapter_indices=adapter_indices, adapter_meta = AdapterBatchMetadata(
adapter_set=adapter_meta.adapter_set, adapter_indices=adapter_indices,
adapter_segments=adapter_segments, adapter_set=adapter_meta.adapter_set,
segment_indices=adapter_meta.segment_indices, adapter_segments=adapter_segments,
) segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights # Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed # TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta( adapter_data = AdapterBatchData.from_meta(
adapter_meta, adapter_meta,
self.layer_to_adapter_weights, self.layer_to_adapter_weights,
prefill, prefill,
batch.prefill_head_indices, batch.prefill_head_indices,
) )
else:
adapter_data = None
out, speculative_logits = self.forward(batch, adapter_data) out, speculative_logits = self.forward(batch, adapter_data)
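A minimal sketch of the per-request bookkeeping the decode path above performs once `accepted_ids` is known: positions advance by the number of accepted tokens, the cache length absorbs the previous input length, and the next step feeds exactly one token per request. Plain PyTorch with hypothetical tensor values; not the TGI API.

```python
import torch

def advance_batch_state(position_ids, cache_lengths, input_lengths, slot_indices, accepted_ids):
    # All tensors are 1D with one entry per request; accepted_ids >= 1.
    position_ids = position_ids + accepted_ids
    cache_lengths = cache_lengths + input_lengths + accepted_ids - 1
    input_lengths = torch.ones_like(input_lengths)  # next decode step: one token per request
    slot_indices = slot_indices + accepted_ids
    return position_ids, cache_lengths, input_lengths, slot_indices

# Example: two requests; the first accepted 2 speculative tokens, the second 1.
pos, cache, inp, slots = advance_batch_state(
    torch.tensor([10, 7]),   # position_ids
    torch.tensor([9, 6]),    # cache_lengths
    torch.tensor([1, 1]),    # input_lengths
    torch.tensor([10, 7]),   # slot_indices
    torch.tensor([2, 1]),    # accepted_ids
)
print(pos.tolist(), cache.tolist(), inp.tolist(), slots.tolist())
# [12, 8] [11, 7] [1, 1] [12, 8]
```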

View File

@ -23,9 +23,11 @@ from text_generation_server.layers.attention import (
_async_h2d_tensor_copy, _async_h2d_tensor_copy,
) )
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import time
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
synchronize, synchronize,
) )
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -486,20 +488,63 @@ class FlashVlmCausalLM(FlashCausalLM):
) )
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
decode_available_memory = graph_free_mem
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(decode_available_memory)} for decode "
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
# only warmup decode, for prefill, image pixel size may change, make the warmup useless # only warmup decode, for prefill, image pixel size may change, make the warmup useless
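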
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
total_batch_seq = 0.001
total_mem = 0
available_mem = decode_available_memory
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
log_master( # Graph memory usage is proportional to seq dimension in a batch
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
) )
for index in range(warmup_times): if graphed_bucket in self.graphed_buckets:
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward( def forward(
self, self,
@ -572,14 +617,21 @@ class FlashVlmCausalLM(FlashCausalLM):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
@ -587,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,
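For reference, a sketch of the memory-budgeted warmup loop introduced above: every bucket is warmed, but a shape is only captured as an HPU graph while its estimated footprint (proportional to the batch/sequence size already captured) still fits the remaining budget, and the estimate is refined with the measured consumption. Pure-Python stand-in with illustrative names; for the decode buckets above the size proxy is simply the batch size.

```python
def budgeted_graph_warmup(buckets, available_mem, warmup_fn):
    """buckets: list of (batch_size, size_proxy); warmup_fn(bucket, capture) warms the
    shape and returns the device memory it consumed. Illustrative only."""
    graphed = set()
    total_batch_seq, total_mem = 0.001, 0.0
    for batch_size, size_proxy in buckets:
        batch_seq = batch_size * size_proxy
        mem_estimate = batch_seq / total_batch_seq * total_mem
        capture = mem_estimate < available_mem
        if capture:
            graphed.add((batch_size, size_proxy))
        used = warmup_fn((batch_size, size_proxy), capture)
        if capture:
            available_mem -= used  # only captured graphs consume the budget
        total_mem += used
        total_batch_seq += batch_seq
    return graphed

# Toy run: pretend each warmed shape costs 1 MB per unit of batch_size * size_proxy.
shapes = [(8, 4), (4, 8), (2, 16), (1, 32)]
print(budgeted_graph_warmup(shapes, available_mem=64.0, warmup_fn=lambda b, c: b[0] * b[1]))
```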

View File

@ -32,6 +32,9 @@ from text_generation_server.utils.import_utils import (
) )
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
import time
import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -187,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
input_ids = np.concatenate(batch.input_ids, dtype=np.int64) input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else: else:
input_ids = batch.input_ids[0] input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64) batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@ -267,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states, image_indices, input_lengths, 1, False cross_attention_states, image_indices, input_lengths, 1, False
) )
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
self.model.forward( self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
@ -280,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices), indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
) )
def warmup_prefill( def warmup_prefill(
@ -325,7 +334,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
self.model.forward( self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
@ -343,26 +354,103 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) def ordering_function_min_tokens(b):
): return (b[0] * b[1], b[1], b[0])
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): buckets = list(
self.warmup_prefill(seq_len, batch_size, batch) sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
graph_free_mem
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
log_master( # Graph memory usage is proportional to seq dimension in a batch
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
) )
for index in range(warmup_times): if graphed_bucket in self.graphed_buckets:
self.warmup_decode(batch_size, block_num, batch) available_mem -= used_mem
synchronize(self.device) total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward( def forward(
self, self,
@ -438,15 +526,22 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = ( batch_size = input_lengths.shape[0]
batch.prefilling if self.limit_hpu_graphs else False seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
) )
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
orig_bs = len(batch) orig_bs = len(batch)
@ -475,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,
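A compact sketch of the `use_graphs`-style decision that now drives `bypass_hpu_graphs` in the forward passes above: derive the shape key the warmup recorded (per-request sequence length for prefill, KV-cache block count for decode) and bypass graphs for shapes that were not captured. Hypothetical helper, assuming a `graphed_buckets` set filled during warmup; not the exact TGI signature.

```python
def should_bypass_hpu_graphs(prefilling, num_input_tokens, num_blocks, batch_size, graphed_buckets):
    # Shape key mirrors what warmup recorded: (batch_size, seqlen_or_blocks, is_prefill).
    seqlen = num_input_tokens // batch_size if prefilling else num_blocks
    return (batch_size, seqlen, prefilling) not in graphed_buckets

# Example: decode with 4 requests over 64 blocks, a shape the warmup captured.
graphed = {(4, 64, False), (4, 128, True)}
print(should_bypass_hpu_graphs(False, 4, 64, 4, graphed))  # False -> run with HPU graphs
```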

View File

@ -206,6 +206,7 @@ def serve(
quantize: Optional[str], quantize: Optional[str],
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
uds_path: Path, uds_path: Path,
max_input_tokens: int, max_input_tokens: int,
@ -218,6 +219,7 @@ def serve(
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
dtype: Optional[str] = None, dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if not is_driver_compatible(): if not is_driver_compatible():
@ -261,6 +263,7 @@ def serve(
quantize, quantize,
speculate, speculate,
data_type, data_type,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
max_input_tokens, max_input_tokens,
adapter_to_index, adapter_to_index,
@ -308,6 +311,7 @@ def serve(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
) )
) )
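The new `kv_cache_dtype` argument threaded through `serve` above ultimately selects the dtype of the KV cache. As a hedged illustration of the idea (the accepted strings and the mapping below are assumptions, not the launcher's documented contract): `auto` keeps the model dtype, while an explicit FP8 value switches the cache to an 8-bit format.

```python
import torch

def resolve_kv_cache_dtype(kv_cache_dtype, model_dtype):
    # Illustrative mapping only; the option names here are assumptions.
    if kv_cache_dtype in (None, "auto"):
        return model_dtype
    if kv_cache_dtype == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")

print(resolve_kv_cache_dtype("auto", torch.bfloat16))  # torch.bfloat16
```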

View File

@ -31,6 +31,7 @@ def main(args):
trust_remote_code=args.trust_remote_code, trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path, uds_path=args.uds_path,
max_input_tokens=args.max_input_tokens, max_input_tokens=args.max_input_tokens,
kv_cache_dtype="auto",
) )

View File

@ -1,18 +1,9 @@
import torch import torch
from loguru import logger
def get_hpu_free_memory(device, memory_fraction): def get_hpu_free_memory(device, memory_fraction):
from habana_frameworks.torch.hpu import memory_stats free_hpu_memory, _ = torch.hpu.mem_get_info()
return free_hpu_memory
device_id = device.index
mem_stats = memory_stats(device_id)
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max(
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
)
return free_memory
def synchronize_hpu(device): def synchronize_hpu(device):
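The simplified `get_hpu_free_memory` above relies on `torch.hpu.mem_get_info()`, which (like `torch.cuda.mem_get_info`) returns a `(free, total)` pair. A minimal usage sketch, assuming the Habana PyTorch bridge is installed:

```python
import torch
import habana_frameworks.torch  # noqa: F401  # loads the Habana bridge / torch.hpu backend

free_bytes, total_bytes = torch.hpu.mem_get_info()
print(f"HPU memory: {free_bytes / 2**30:.1f} GiB free of {total_bytes / 2**30:.1f} GiB")
```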

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, List
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
@ -18,6 +18,8 @@ class _QuantizerConfig:
groupsize: int groupsize: int
quant_method: str quant_method: str
sym: bool sym: bool
weight_block_size: Optional[List[int]]
modules_to_not_convert: List[str]
@dataclass @dataclass
@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
activation_scale_ub: float activation_scale_ub: float
# We should probably do this with Pytantic JSON deserialization, def _get_config_json(model_id: str, revision: Optional[str], filename: str):
if os.path.exists(
os.path.join(
model_id,
filename,
)
):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
return json.load(f)
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params. # but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision): def _get_quantizer_config(model_id, revision):
bits = 4 bits = 4
@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format = None checkpoint_format = None
sym = False sym = False
desc_act = False desc_act = False
weight_block_size = None
modules_to_not_convert = []
filename = "config.json" filename = "config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
# FP8 config # FP8 config
if data["quantization_config"]["quant_method"] == "fbgemm_fp8": if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
return _FP8QuantizerConfig( return _FP8QuantizerConfig(
activation_scale_ub=data["quantization_config"]["activation_scale_ub"] activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
) )
weight_block_size = data["quantization_config"].get("weight_block_size", None)
if "zero_point" in data["quantization_config"]: if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"] sym = not data["quantization_config"]["zero_point"]
@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
# Order is important here, desc_act is missing on some real models # Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"] quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format") checkpoint_format = data["quantization_config"].get("checkpoint_format")
desc_act = data["quantization_config"]["desc_act"] desc_act = data["quantization_config"].get("desc_act", False)
modules_to_not_convert = data["quantization_config"].get(
"modules_to_not_convert", []
)
if modules_to_not_convert is None:
modules_to_not_convert = []
except Exception: except Exception:
filename = "quantize_config.json" filename = "quantize_config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["bits"] bits = data["bits"]
groupsize = data["group_size"] groupsize = data["group_size"]
@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
except Exception: except Exception:
filename = "quant_config.json" filename = "quant_config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["w_bit"] bits = data["w_bit"]
groupsize = data["q_group_size"] groupsize = data["q_group_size"]
desc_act = data["desc_act"] desc_act = data["desc_act"]
@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format=checkpoint_format, checkpoint_format=checkpoint_format,
sym=sym, sym=sym,
desc_act=desc_act, desc_act=desc_act,
weight_block_size=weight_block_size,
modules_to_not_convert=modules_to_not_convert,
) )
@ -134,6 +139,7 @@ def get_loader(
quant_method=quantizer_config.quant_method, quant_method=quantizer_config.quant_method,
quantize=quantize, quantize=quantize,
sym=quantizer_config.sym, sym=quantizer_config.sym,
modules_to_not_convert=quantizer_config.modules_to_not_convert,
) )
elif quantize == "fp8" or quantize is None: elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@ -141,9 +147,14 @@ def get_loader(
# Since the default for the quantize config is _QuantizerConfig, # Since the default for the quantize config is _QuantizerConfig,
# we need to add this check to not get an attribute error # we need to add this check to not get an attribute error
activation_scale_ub = None activation_scale_ub = None
weight_block_size = quantizer_config.weight_block_size
if isinstance(quantizer_config, _FP8QuantizerConfig): if isinstance(quantizer_config, _FP8QuantizerConfig):
activation_scale_ub = quantizer_config.activation_scale_ub activation_scale_ub = quantizer_config.activation_scale_ub
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") return HybridFP8UnquantLoader(
activation_scale_ub,
to_fp8=quantize == "fp8",
weight_block_size=weight_block_size,
)
else: else:
raise ValueError(f"Unknown quantization method: {quantize}") raise ValueError(f"Unknown quantization method: {quantize}")
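To illustrate the two fields `_get_quantizer_config` now extracts, here is a hedged example of a `quantization_config` block and how the new code reads it; the concrete values are made up for illustration.

```python
# Hypothetical quantization_config from a model's config.json.
quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],         # block-wise scales, forwarded to HybridFP8UnquantLoader
    "modules_to_not_convert": ["lm_head"],   # modules to keep unquantized
    "zero_point": False,
}

weight_block_size = quantization_config.get("weight_block_size", None)
modules_to_not_convert = quantization_config.get("modules_to_not_convert", []) or []
sym = not quantization_config.get("zero_point", False)
print(weight_block_size, modules_to_not_convert, sym)  # [128, 128] ['lm_head'] True
```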

View File

@ -62,6 +62,14 @@ class WeightsLoader(ABC):
""" """
... ...
@abstractmethod
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelism, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod @abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str): def get_weights_row(self, weights: "Weights", prefix: str):
""" """
@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
weights.get_sharded(f"{prefix}.weight", dim=1), weights.get_sharded(f"{prefix}.weight", dim=1),
) )
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
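A sketch of what the default (unquantized) `get_multi_weights` path above does: load each prefix's full weight and concatenate along the requested dimension, for example fusing attention projections. Hypothetical tensors and shapes.

```python
import torch

def get_multi_weights_default(tensors, prefixes, dim):
    # Mirrors DefaultWeightsLoader.get_multi_weights: no sharding, plain concatenation.
    return torch.cat([tensors[f"{p}.weight"] for p in prefixes], dim=dim)

# Example: fuse q/k/v projection weights along the output dimension.
tensors = {
    "attn.q_proj.weight": torch.randn(64, 32),
    "attn.k_proj.weight": torch.randn(16, 32),
    "attn.v_proj.weight": torch.randn(16, 32),
}
fused = get_multi_weights_default(tensors, ["attn.q_proj", "attn.k_proj", "attn.v_proj"], dim=0)
print(fused.shape)  # torch.Size([96, 32])
```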
class Weights: class Weights:
def __init__( def __init__(
@ -393,6 +405,9 @@ class Weights:
def get_weights_row(self, prefix: str): def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix) return self.weights_loader.get_weights_row(self, prefix)
def get_multi_weights(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights(self, prefixes, dim)
@contextmanager @contextmanager
def use_loader(self, weights_loader: WeightsLoader): def use_loader(self, weights_loader: WeightsLoader):
""" """

View File

@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque; use std::collections::VecDeque;
use text_generation_router::infer::InferError; use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse; use text_generation_router::infer::InferStreamResponse;
use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{ use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters, ValidStoppingParameters,
@ -185,6 +186,9 @@ struct State {
/// Paged Attention Block Allocation /// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>, block_allocator: Option<BlockAllocator>,
/// Indicates whether this is an HPU device; HPU devices need padding to generate the first token.
is_hpu_device: bool,
} }
impl State { impl State {
@ -214,6 +218,7 @@ impl State {
speculate, speculate,
support_chunking, support_chunking,
block_allocator, block_allocator,
is_hpu_device: Env::new().is_hpu_device(),
} }
} }
@ -368,6 +373,21 @@ impl State {
} }
} }
if self.is_hpu_device {
//HPU needs to pad for the prefill
max_input_length = max_input_length.max(entry.request.input_length);
let actual_prefill_tokens_for_hpu =
(batch.len() + 1) as u32 * max_input_length;
if actual_prefill_tokens_for_hpu > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len; prefill_tokens += postfix_len;
Some(block_allocation) Some(block_allocation)
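The HPU branch above assumes every prompt in the padded prefill batch costs as much as the longest one, so the admission check compares `(batch_len + 1) * max_input_length` against the prefill token budget. The same accounting as a small Python sketch:

```python
def hpu_padded_prefill_tokens(current_batch_len, max_input_length):
    # +1 accounts for the entry currently being considered for admission.
    return (current_batch_len + 1) * max_input_length

# Example: 3 entries already selected, the longest prompt so far is 512 tokens,
# and the candidate entry does not raise max_input_length.
prefill_token_budget = 2048
needed = hpu_padded_prefill_tokens(3, 512)   # 4 * 512 = 2048
print(needed <= prefill_token_budget)        # True: the entry still fits the budget
```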

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "3.3.0-dev0" "version": "3.3.1-dev0"
}, },
"paths": { "paths": {
"/": { "/": {

View File

@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \ docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
--model-id $model --model-id $model
``` ```
@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \ docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
--model-id $model --model-id $model
<text-generation-inference-launcher-arguments> <text-generation-inference-launcher-arguments>
``` ```
@ -115,7 +115,7 @@ docker run -p 8080:80 \
-e BATCH_BUCKET_SIZE=256 \ -e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \ -e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
--model-id $model \ --model-id $model \
--sharded true --num-shard 8 \ --sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \ --max-input-tokens 1024 --max-total-tokens 2048 \
@ -141,7 +141,7 @@ docker run -p 8080:80 \
-v $volume:/data \ -v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \ -e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \ -e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
--model-id $model \ --model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4 --max-total-tokens 8192 --max-batch-size 4
@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-e PROF_PATH=/tmp/hpu_profile \ -e PROF_PATH=/tmp/hpu_profile \
-e PROF_RANKS=0 \ -e PROF_RANKS=0 \
-e PROF_RECORD_SHAPES=True \ -e PROF_RECORD_SHAPES=True \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
--model-id $model --model-id $model
``` ```

View File

@ -31,7 +31,7 @@ deployment instructions in the model card:
The service is launched simply by running the text-generation-inference container with two sets of parameters: The service is launched simply by running the text-generation-inference container with two sets of parameters:
``` ```
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.0-neuron <service_parameters> docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.1-neuron <service_parameters>
``` ```
- system parameters are used to map ports, volumes and devices between the host and the service, - system parameters are used to map ports, volumes and devices between the host and the service,

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \ --shm-size 1g \
-e HF_TOKEN=$token \ -e HF_TOKEN=$token \
-p 8080:80 \ -p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \
--model-id $model --model-id $model
``` ```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes
``` ```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes-nf4 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4
``` ```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize gptq docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq
``` ```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \ --device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \ --ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0-rocm \ ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \
--model-id $model --model-id $model
``` ```

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0-intel-xpu \ ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0-intel-cpu \ ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0 \ ghcr.io/huggingface/text-generation-inference:3.3.1 \
--model-id $model --model-id $model
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0 \ ghcr.io/huggingface/text-generation-inference:3.3.1 \
--model-id $model --model-id $model
``` ```
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:3.3.0 --help docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help
``` ```
</Tip> </Tip>

View File

@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class # create Hugging Face Model Class
huggingface_model = HuggingFaceModel( huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.0"), image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"),
env=hub, env=hub,
role=role, role=role,
) )

View File

@ -102,7 +102,7 @@
"flake-parts": "flake-parts_3", "flake-parts": "flake-parts_3",
"nix-test-runner": "nix-test-runner_3", "nix-test-runner": "nix-test-runner_3",
"nixpkgs": [ "nixpkgs": [
"tgi-nix", "hf-nix",
"nixpkgs" "nixpkgs"
], ],
"pre-commit-hooks": "pre-commit-hooks_3" "pre-commit-hooks": "pre-commit-hooks_3"
@ -579,6 +579,26 @@
"type": "github" "type": "github"
} }
}, },
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat_4",
"flake-utils": "flake-utils_7",
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1747919133,
"narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
"type": "github"
},
"original": {
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nix-filter": { "nix-filter": {
"locked": { "locked": {
"lastModified": 1731533336, "lastModified": 1731533336,
@ -718,16 +738,16 @@
}, },
"nixpkgs_6": { "nixpkgs_6": {
"locked": { "locked": {
"lastModified": 1737453259, "lastModified": 1747820358,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk", "owner": "danieldk",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", "rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "danieldk", "owner": "danieldk",
"ref": "outlines-v0.1.4-tgi", "ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }
@ -836,19 +856,19 @@
"inputs": { "inputs": {
"crate2nix": "crate2nix", "crate2nix": "crate2nix",
"flake-utils": "flake-utils_6", "flake-utils": "flake-utils_6",
"hf-nix": "hf-nix",
"nix-filter": "nix-filter", "nix-filter": "nix-filter",
"nixpkgs": [ "nixpkgs": [
"tgi-nix", "hf-nix",
"nixpkgs" "nixpkgs"
], ],
"rust-overlay": "rust-overlay", "rust-overlay": "rust-overlay"
"tgi-nix": "tgi-nix"
} }
}, },
"rust-overlay": { "rust-overlay": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": [
"tgi-nix", "hf-nix",
"nixpkgs" "nixpkgs"
] ]
}, },
@ -970,27 +990,6 @@
"repo": "default", "repo": "default",
"type": "github" "type": "github"
} }
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat_4",
"flake-utils": "flake-utils_7",
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1746795305,
"narHash": "sha256-4fpUT4j4w0NDKF22KvG7iGmwQTBPM5SrPEqt+N3fqF0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "359cd25f31f0f2ad2cadfbf4e180780a7a06e3c5",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "torch-2.7",
"repo": "text-generation-inference-nix",
"type": "github"
}
} }
}, },
"root": "root", "root": "root",

View File

@ -2,15 +2,15 @@
inputs = { inputs = {
crate2nix = { crate2nix = {
url = "github:nix-community/crate2nix"; url = "github:nix-community/crate2nix";
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "hf-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7"; hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {
url = "github:oxalica/rust-overlay"; url = "github:oxalica/rust-overlay";
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "hf-nix/nixpkgs";
}; };
}; };
outputs = outputs =
@ -21,7 +21,7 @@
nixpkgs, nixpkgs,
flake-utils, flake-utils,
rust-overlay, rust-overlay,
tgi-nix, hf-nix,
}: }:
flake-utils.lib.eachDefaultSystem ( flake-utils.lib.eachDefaultSystem (
system: system:
@ -33,10 +33,10 @@
}; };
pkgs = import nixpkgs { pkgs = import nixpkgs {
inherit system; inherit system;
inherit (tgi-nix.lib) config; inherit (hf-nix.lib) config;
overlays = [ overlays = [
rust-overlay.overlays.default rust-overlay.overlays.default
tgi-nix.overlays.default hf-nix.overlays.default
(import nix/overlay.nix) (import nix/overlay.nix)
]; ];
}; };

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 42, "completion_tokens": 42,
"prompt_tokens": 277, "prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 62, "completion_tokens": 62,
"prompt_tokens": 277, "prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 67, "completion_tokens": 67,
"prompt_tokens": 277, "prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 72, "completion_tokens": 72,
"prompt_tokens": 275, "prompt_tokens": 275,

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 80, "completion_tokens": 80,
"prompt_tokens": 279, "prompt_tokens": 279,

View File

@ -14,7 +14,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 35, "completion_tokens": 35,
"prompt_tokens": 32, "prompt_tokens": 32,

View File

@ -14,7 +14,7 @@
"id": "", "id": "",
"model": "google/gemma-3-4b-it", "model": "google/gemma-3-4b-it",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 44, "completion_tokens": 44,
"prompt_tokens": 37, "prompt_tokens": 37,

View File

@ -18,7 +18,7 @@
"id": "", "id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct", "model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 45, "prompt_tokens": 45,
@ -44,7 +44,7 @@
"id": "", "id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct", "model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 45, "prompt_tokens": 45,

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct", "model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "3.3.0-dev0-native", "system_fingerprint": "3.3.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 45, "prompt_tokens": 45,

View File

@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option<usize> {
let devices = match env::var("CUDA_VISIBLE_DEVICES") { let devices = match env::var("CUDA_VISIBLE_DEVICES") {
Ok(devices) => devices, Ok(devices) => devices,
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") { Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
Ok(devices) => devices, Ok(devices) => {
if devices.trim() == "all" {
// Count the number of all GPUs via nvidia-smi
let output = Command::new("nvidia-smi")
.args(["--query-gpu=uuid", "--format=csv,noheader"])
.output()
.ok()?;
String::from_utf8_lossy(&output.stdout)
.lines()
.filter(|line| !line.trim().is_empty())
.count()
.to_string()
} else {
devices
}
}
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
}, },
}; };
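An equivalent of the new `NVIDIA_VISIBLE_DEVICES=all` fallback above, sketched in Python for illustration: count one GPU per UUID line reported by `nvidia-smi` (assumes `nvidia-smi` is on PATH).

```python
import subprocess

def count_visible_gpus():
    # One UUID line per GPU; empty lines are ignored, mirroring the launcher logic.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=uuid", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return sum(1 for line in out.splitlines() if line.strip())

print(count_visible_gpus())
```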

View File

@ -1,6 +1,7 @@
{ {
buildPythonPackage, buildPythonPackage,
poetry-core, poetry-core,
aiohttp,
huggingface-hub, huggingface-hub,
pydantic, pydantic,
}: }:
@ -15,6 +16,7 @@ buildPythonPackage {
build-system = [ poetry-core ]; build-system = [ poetry-core ];
dependencies = [ dependencies = [
aiohttp
huggingface-hub huggingface-hub
pydantic pydantic
]; ];

View File

@ -13,26 +13,26 @@ final: prev: {
( (
python-self: python-super: with python-self; { python-self: python-super: with python-self; {
# Python package override example: # Python package override example:
transformers = python-super.transformers.overrideAttrs ( #transformers = python-super.transformers.overrideAttrs (
_: _: { # _: _: {
src = final.fetchFromGitHub { # src = final.fetchFromGitHub {
owner = "huggingface"; # owner = "huggingface";
repo = "transformers"; # repo = "transformers";
rev = "v4.51.0"; # rev = "v4.51.0";
hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8="; # hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8=";
}; # };
} # }
); #);
huggingface-hub = python-super.huggingface-hub.overrideAttrs ( #huggingface-hub = python-super.huggingface-hub.overrideAttrs (
_: _: { # _: _: {
src = final.fetchFromGitHub { # src = final.fetchFromGitHub {
owner = "huggingface"; # owner = "huggingface";
repo = "huggingface_hub"; # repo = "huggingface_hub";
rev = "v0.30.0"; # rev = "v0.30.0";
hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o="; # hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o=";
}; # };
} # }
); #);
} }
) )
]; ];

View File

@ -31,7 +31,7 @@
peft, peft,
pillow, pillow,
prometheus-client, prometheus-client,
punica-kernels, punica-sgmv,
py-cpuinfo, py-cpuinfo,
pydantic, pydantic,
quantization, quantization,
@ -107,7 +107,7 @@ buildPythonPackage {
peft peft
pillow pillow
prometheus-client prometheus-client
punica-kernels punica-sgmv
py-cpuinfo py-cpuinfo
pydantic pydantic
quantization quantization

View File

@ -3,7 +3,6 @@ include Makefile-flash-att-v2
include Makefile-vllm include Makefile-vllm
include Makefile-awq include Makefile-awq
include Makefile-selective-scan include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-exllamav2 include Makefile-exllamav2
include Makefile-flashinfer include Makefile-flashinfer

View File

@ -1,12 +0,0 @@
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
build-lorax-punica:
if [ ! -d 'lorax-punica' ]; then \
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
fi
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
cd lorax-punica && git submodule update --init --recursive
cd lorax-punica/server/punica_kernels && python setup.py build
install-lorax-punica: build-lorax-punica
cd lorax-punica/server/punica_kernels && python setup.py install

View File

@ -163,6 +163,64 @@
} }
} }
}, },
{
"repo_id": "kernels-community/punica-sgmv",
"sha": "9ae1b469cb39c33df9ddd61657c6359acc423714",
"variants": {
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1",
"hash_type": "git_lfs_concat"
}
}
},
{ {
"repo_id": "kernels-community/quantization", "repo_id": "kernels-community/quantization",
"sha": "6470f9b005797e00279eb9103463dfe0f8b7da00", "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",

View File

@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta"
[tool.kernels.dependencies] [tool.kernels.dependencies]
"kernels-community/paged-attention" = ">=0.0.2" "kernels-community/paged-attention" = ">=0.0.2"
"kernels-community/moe" = ">=0.1.1" "kernels-community/moe" = ">=0.1.1"
"kernels-community/punica-sgmv" = ">=0.0.1"
"kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization" = ">=0.0.3"
"kernels-community/quantization-eetq" = ">=0.0.1" "kernels-community/quantization-eetq" = ">=0.0.1"
"kernels-community/rotary" = ">=0.0.1" "kernels-community/rotary" = ">=0.0.1"

View File

@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.adapters.weights import ( from text_generation_server.adapters.weights import (
AdapterBatchMetadata, AdapterBatchMetadata,
AdapterWeights, AdapterWeights,
BatchAdapterWeights, BatchAdapterWeights,
) )
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK, if SYSTEM == "cuda":
MAX_RANK_CUSTOM, punica_sgmv = load_kernel(
get_tmp_tensors, module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
orient_for_rank, )
pad_rank, else:
use_cutlass_shrink, punica_sgmv = None
has_sgmv,
)
def get_start_stop_idxs_for_rank(offset, size, rank, world_size): def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False self._is_transposed = False
# [num_layers, hidden_size, r] # [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
]
self._weights_a = torch.stack(weights_a) self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size] # [num_layers, r, hidden_size]
@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights):
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv # pad lora ranks to be compatible with sgmv
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] lora_a_list = [
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
]
if lora_a_list: if lora_a_list:
# update rank if it was padded # update rank if it was padded
@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights):
def can_vectorize(self, pg: ProcessGroup) -> bool: def can_vectorize(self, pg: ProcessGroup) -> bool:
return all( return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
for rank_data in self.rank_data.values() for rank_data in self.rank_data.values()
) )
@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights):
) )
use_sgmv = False use_sgmv = False
if prefill or max_rank > BGMV_MAX_RANK: if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
if has_sgmv(): if punica_sgmv is not None:
use_sgmv = True use_sgmv = True
lora_a_ptr = torch.tensor( lora_a_ptr = torch.tensor(
[ [
@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights):
if use_sgmv: if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices] lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors( tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device lora_a_ptr_indices.size(0), rank, device
) )
segment_starts = meta.adapter_segments[indices] segment_starts = meta.adapter_segments[indices]
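The branch above is the SGMV/BGMV dispatch rule; a hypothetical helper (not part of this diff) states it directly: SGMV handles prefill batches and ranks beyond the BGMV limit, provided the punica_sgmv module could be loaded.

def choose_sgmv(prefill: bool, max_rank: int, punica_sgmv) -> bool:
    # Equivalent to the branch above on CUDA, with the None check hoisted so
    # the helper is also safe when the kernel module was not loaded.
    if punica_sgmv is None:
        return False
    return prefill or max_rank > punica_sgmv.BGMV_MAX_RANK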

View File

@ -5,14 +5,16 @@ import torch.distributed
from torch import nn from torch import nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    punica_sgmv = load_kernel(
        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
    )
else:
    punica_sgmv = None
if TYPE_CHECKING: if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData from text_generation_server.adapters import AdapterBatchData
@ -41,7 +43,11 @@ class LoraLinear(nn.Module):
return result return result
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
if (
    punica_sgmv is not None
    and data is not None
    and data.can_vectorize(self.process_group)
):
# In tensor-parallel configurations, each GPU processes a specific segment of the output. # In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on # The 'result' tensor represents the full output, which can vary in size based on
# the layer type (e.g., attention vs. feed-forward layers). We define the current # the layer type (e.g., attention vs. feed-forward layers). We define the current
@ -68,7 +74,7 @@ class LoraLinear(nn.Module):
if data.use_sgmv: if data.use_sgmv:
# Use SGMV for prefill # Use SGMV for prefill
v = lora_a_sgmv_cutlass( v = punica_sgmv.lora_a_sgmv_cutlass(
input, input,
rank_segments.tmp_shrink, rank_segments.tmp_shrink,
lora_a_ptr, lora_a_ptr,
@ -81,7 +87,7 @@ class LoraLinear(nn.Module):
if self.process_group.size() > 1: if self.process_group.size() > 1:
v = self.collect_lora_a(v) v = self.collect_lora_a(v)
lora_b_sgmv_cutlass( punica_sgmv.lora_b_sgmv_cutlass(
proj, proj,
v, v,
rank_segments.tmp_expand, rank_segments.tmp_expand,
@ -96,7 +102,7 @@ class LoraLinear(nn.Module):
(input.size(0), r), dtype=input.dtype, device=input.device (input.size(0), r), dtype=input.dtype, device=input.device
) )
# TODO: error with [-1, 0], but not [0, -1] # TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv( punica_sgmv.add_lora_a_bgmv(
v, v,
input, input,
lora_a_ptr, lora_a_ptr,
@ -107,7 +113,7 @@ class LoraLinear(nn.Module):
if self.process_group.size() > 1: if self.process_group.size() > 1:
v = self.collect_lora_a(v) v = self.collect_lora_a(v)
add_lora_b_bgmv( punica_sgmv.add_lora_b_bgmv(
proj, proj,
v, v,
lora_b_ptr, lora_b_ptr,
@ -142,7 +148,7 @@ class LoraLinear(nn.Module):
lora_a = data.lora_a[adapter_index][self.layer_id, :, :] lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :] lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, lora_b.size(0)) lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))
a_out = input @ lora_a a_out = input @ lora_a
if self.process_group.size() > 1: if self.process_group.size() > 1:
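For reference, the non-vectorized fallback shown above reduces to the standard two-matmul LoRA update; a self-contained sketch with illustrative shapes (all tensor names here are hypothetical, and the adapter scale is shown pre-folded into lora_b, as done in the adapter loader):

import torch

tokens, hidden, r = 4, 64, 16
x = torch.randn(tokens, hidden)
lora_a = torch.randn(hidden, r)        # A projects activations down to the adapter rank
lora_b = torch.randn(r, hidden) * 0.5  # B projects back up; scale already multiplied in
delta = (x @ lora_a) @ lora_b          # low-rank correction added onto the base projection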

View File

@ -1,252 +0,0 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/sgmv.py
# License: Apache License Version 2.0, January 2004
import os
import warnings
from functools import lru_cache
from typing import List, Tuple
import torch
import torch.nn.functional as F
try:
import punica_kernels as _kernels
HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
_kernels = None
HAS_SGMV = False
MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64
def has_sgmv() -> bool:
return HAS_SGMV
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
"""Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
if not has_sgmv():
return t
# tensor parallelism will result in effective rank being divided by world_size,
# so we need to scale the min rank to offset that effect
min_rank = MIN_SGMV_RANK * world_size
# if we're at or below the min rank, pad up to the min rank
# otherwise, pad to the nearest multiple of the block size
current_rank = t.size(dim)
target_rank = (
min_rank
if current_rank <= min_rank
else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
)
if current_rank == target_rank:
return t
pad_size = target_rank - current_rank
# see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
pad = [0, 0] * t.dim()
pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
pad = tuple(pad)
return F.pad(t, pad, mode="constant", value=0.0)
def use_cutlass_shrink(lora_rank: int) -> bool:
return lora_rank < MIN_RANK_CUSTOM
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
return t.transpose(0, 1)
return t
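# Usage sketch for the three helpers above (illustrative only, not part of the
# original module): a rank-10 adapter is padded to the next SGMV-friendly rank,
# which then decides the shrink path and the orientation of the A matrix.
def _pad_rank_example() -> None:
    t = torch.randn(4096, 10)              # [hidden_size, r] with an odd rank of 10
    t = pad_rank(t, dim=1, world_size=1)   # padded to r = 16 when the kernels are present
    print(use_cutlass_shrink(16))          # False: rank 16 takes the custom shrink path
    print(orient_for_rank(t, 16).shape)    # transposed, since 16 <= rank <= 128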
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.Tensor,
s_end: torch.Tensor,
layer_idx: int,
lora_rank: int,
):
"""
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
layer_idx: Layer index of the weight matrices.
"""
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
# Custom SGMV shrink only supports rank 16, 32, 64, 128
_add_lora_sgmv_cutlass_legacy(
y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
)
return
tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
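# Pure-PyTorch reference for the semantics documented above, written against
# explicit weight tensors instead of device pointers (wa[i] / wb[i] are the
# [num_layers, R, H1] / [num_layers, R, H2] matrices behind wa_ptr[i] / wb_ptr[i]);
# a sketch for clarity, not the kernel's actual interface.
def _add_lora_sgmv_reference(
    y: torch.Tensor,
    x: torch.Tensor,
    wa: List[torch.Tensor],
    wb: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    for i in range(len(wa)):
        start, end = int(s_start[i]), int(s_end[i])
        if end <= start:
            continue
        y[start:end] += x[start:end] @ wa[i][layer_idx].T @ wb[i][layer_idx]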
def _add_lora_sgmv_cutlass_legacy(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
):
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
return torch.empty((size,), dtype=torch.uint8, device=device)
def get_tmp_expand_size(size: int) -> int:
return _kernels.sgmv_cutlass_tmp_size(size)
def get_tmp_tensors(
nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
has_sgmv_available = has_sgmv()
if use_cutlass:
tmp = get_tmp_tensor_for_size(nsegments, device)
return tmp, tmp
elif has_sgmv_available:
return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
else:
    # Without the punica kernels, sgmv_cutlass_tmp_size is unavailable; fall
    # back to the size-based allocation that needs no kernels.
    tmp = get_tmp_tensor_for_size_no_kernels(nsegments, device)
    return tmp, tmp
def lora_a_sgmv_cutlass(
x: torch.Tensor,
tmp: torch.Tensor,
wa_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
) -> torch.Tensor:
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
else:
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
return v
def lora_b_sgmv_cutlass(
y: torch.Tensor,
v: torch.Tensor,
tmp: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
):
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
v: Shape: `[B, R]`. Temporary vector.
x: Shape: `[B, H1]`. Input vectors.
wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
"""
def add_lora_a_bgmv(
v: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
def add_lora_b_bgmv(
y: torch.Tensor,
v: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
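# Pure-PyTorch reference for the BGMV semantics documented above (shapes as in
# that docstring: wa_T_all is [num_loras, L, R, H1], wb_T_all is [num_loras, L, H2, R]);
# a sketch for clarity, not the dispatch_bgmv kernel interface.
def _add_lora_bgmv_reference(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float = 1.0,
):
    for i in range(x.size(0)):
        v = x[i].unsqueeze(0) @ wa_T_all[indicies[i], layer_idx].transpose(-1, -2)
        y[i] += (v @ wb_T_all[indicies[i], layer_idx].transpose(-1, -2) * scale).squeeze(0)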
def segmented_matmul(
y: torch.Tensor,
x: torch.Tensor,
w: List[torch.Tensor],
b: List[torch.Tensor],
s_start: torch.IntTensor,
s_end: torch.IntTensor,
):
for i in range(len(w)):
if s_end[i] - s_start[i] <= 0:
continue
xi = x[s_start[i] : s_end[i]]
wi = w[i]
bi = b[i]
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
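# Usage sketch for segmented_matmul (illustrative only, not part of the original
# module): two segments of a batch routed through two different weight/bias pairs.
def _segmented_matmul_example() -> None:
    x = torch.randn(6, 16)
    y = torch.zeros(6, 8)
    w = [torch.randn(8, 16), torch.randn(8, 16)]       # F.linear expects [out, in]
    b = [torch.zeros(8), torch.zeros(8)]
    s_start = torch.tensor([0, 3], dtype=torch.int32)
    s_end = torch.tensor([3, 6], dtype=torch.int32)
    segmented_matmul(y, x, w, b, s_start, s_end)       # rows 0-2 use w[0], rows 3-5 use w[1]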