mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

Merge branch 'huggingface:main' into qwen3_moe
This commit is contained in commit 45d95bdccc
.github/workflows/nix_build.yaml (vendored, 2 lines changed)

@@ -21,7 +21,7 @@ jobs:
     nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
   with:
-    name: text-generation-inference
+    name: huggingface
     # If you chose signing key for write access
     authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
   env:
.github/workflows/nix_cache.yaml (vendored, 2 lines changed)

@@ -20,7 +20,7 @@ jobs:
     nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
   with:
-    name: text-generation-inference
+    name: huggingface
     # If you chose signing key for write access
     authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
   env:
.github/workflows/nix_tests.yaml (vendored, 2 lines changed)

@@ -25,7 +25,7 @@ jobs:
     nix_path: nixpkgs=channel:nixos-unstable
 - uses: cachix/cachix-action@v14
   with:
-    name: text-generation-inference
+    name: huggingface
     # If you chose signing key for write access
     authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
   env:
Cargo.lock (generated, 16 lines changed)

@@ -4650,7 +4650,7 @@ dependencies = [

 [[package]]
 name = "text-generation-backends-trtllm"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.32",

@@ -4671,7 +4671,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "average",
  "clap 4.5.32",

@@ -4691,7 +4691,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",

@@ -4709,7 +4709,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "clap 4.5.32",
  "ctrlc",

@@ -4730,7 +4730,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "anyhow",
  "async-stream",

@@ -4782,7 +4782,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-llamacpp"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "async-trait",
  "bindgen 0.71.1",

@@ -4800,7 +4800,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v2"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",

@@ -4849,7 +4849,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v3"
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -21,7 +21,7 @@ default-members = [
 resolver = "2"

 [workspace.package]
-version = "3.3.0-dev0"
+version = "3.3.1-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
 # Build specific version of transformers
 RUN . .venv/bin/activate && make build-awq

-# Build Lorax Punica kernels
-FROM kernel-builder AS lorax-punica-builder
-WORKDIR /usr/src
-COPY server/Makefile-lorax-punica Makefile
-# Build specific version of transformers
-RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
-
 # Build Transformers CUDA kernels
 FROM kernel-builder AS custom-kernels-builder
 WORKDIR /usr/src

@@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
@@ -6,7 +6,7 @@
 FROM nixos/nix:2.18.8 AS builder
 RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
 RUN nix profile install nixpkgs#cachix
-RUN cachix use text-generation-inference
+RUN cachix use huggingface
 WORKDIR /root
 ADD . .
 RUN nix build .
@@ -1,5 +1,5 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.20.0
+ARG HABANA_VERSION=1.21.0
 ARG PYTORCH_VERSION=2.6.0

 # Rust builder

@@ -60,6 +60,9 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
 ENV ATTENTION=default
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
+ENV PT_HPU_LAZY_MODE=1
+ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true

 # Text Generation Inference base env
 ENV HF_HOME=/data \

@@ -95,7 +98,8 @@ RUN cd server && \
 pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
 BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
 pip install . --no-cache-dir
-RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
+ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
 ```

 And then you can make requests like

@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \

 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```

@@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 token=<your cli READ token>

 docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
+ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
 ```

 ### A note on Shared Memory (shm)

@@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt
 we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
 be pulled from a binary cache, removing the need to build them locally.

-First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
+First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
 Setting up the cache is important, otherwise Nix will build many of the dependencies
 locally, which can take hours.
@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := ${mkfile_dir}/../..

-HABANA_VERSION := 1.20.0
+HABANA_VERSION := 1.21.0
 PYTORCH_VERSION := 2.6.0

 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
@@ -26,6 +26,11 @@ class Dtype(str, Enum):
    bloat16 = "bfloat16"


class KVCacheDtype(str, Enum):
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


@app.command()
def serve(
    model_id: str,

@@ -34,6 +39,7 @@ def serve(
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    kv_cache_dtype: Optional[KVCacheDtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",

@@ -93,7 +99,8 @@ def serve(
    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = "bfloat16" if dtype is None else dtype.value
    logger.info(f"quantize={quantize}")
    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",

@@ -175,6 +182,7 @@ def serve(
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
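As a side note on the CLI change above: a minimal, illustrative sketch (not part of the commit) of how the new `KVCacheDtype` option flows from the typer enum down to a `torch.dtype`. The helper name `resolve_kv_cache_dtype` is hypothetical; it merely condenses the enum downgrade in `serve()` and the string-to-dtype mapping added to `get_model()` later in this diff.

```python
from enum import Enum
from typing import Optional

import torch


class KVCacheDtype(str, Enum):
    # Mirrors the enum added to cli.py above.
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[KVCacheDtype], dtype: torch.dtype
) -> torch.dtype:
    # Downgrade the enum into a plain string, as the CLI does, then pick
    # the torch dtype the same way get_model() does in this commit.
    value = None if kv_cache_dtype is None else kv_cache_dtype.value
    if value == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    if value == "fp8_e5m2":
        return torch.float8_e5m2
    return dtype


print(resolve_kv_cache_dtype(KVCacheDtype.fp8_e4m3fn, torch.bfloat16))
print(resolve_kv_cache_dtype(None, torch.bfloat16))
```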
@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear

from text_generation_server.layers.lora import (
    LoraLinear,

@@ -27,6 +28,7 @@ __all__ = [
    "TensorParallelEmbedding",
    "SpeculativeHead",
    "LoraLinear",
    "Fp8Linear",
    "TensorParallelMultiAdapterLinear",
    "TensorParallelAdapterRowLinear",
    "load_layer_norm",
@@ -10,18 +10,21 @@ from .hpu import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
    paged_attention_mla,
)


# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales
from .kv_cache import KVCache, get_kv_scales, KVCompressCache

__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "paged_attention_mla",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "KVCompressCache",
    "Seqlen",
    "HPUPagedAttentionMetadata",
    "trim_seqlen_metadata",
@@ -90,6 +90,8 @@ class Seqlen:
def _async_h2d_tensor_copy(source, device="hpu"):
    if source is None:
        return None
    if source.device.type == "hpu":
        return source
    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
    target = torch.empty(source.shape, dtype=source.dtype, device=device)
    target.copy_(source, non_blocking=True)
@@ -7,15 +7,66 @@ from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE

SUPPORTS_WINDOWING = False


def fetch_from_cache(cache, blocks):
    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
        return cache[: blocks.size(0)]
    else:
        return cache.index_select(0, blocks)


class FP8Matmul(torch.nn.Module):

    def __init__(self, scale_other):
        super().__init__()
        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
        self.scale_other = scale_other

    def quant_input(self, x, scale):
        return torch.ops.hpu.cast_to_fp8_v2(
            x, scale, False, False, torch.float8_e4m3fn
        )[0]

    def matmul_fp8(
        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
    ):
        return torch.ops.hpu.fp8_gemm_v2(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )

    def forward(self, input, other):
        qinput = self.quant_input(input, self.scale_input)
        qother = self.quant_input(other, self.scale_other)
        output = self.matmul_fp8(
            qinput,
            qother,
            out_dtype=torch.bfloat16,
            scale_input_inv=1.0 / self.scale_input,
            scale_other_inv=1.0 / self.scale_other,
        )
        return output


class FetchFromCache(torch.nn.Module):

    def __init__(self, scale_inv):
        super().__init__()
        self.scale_inv = scale_inv

    def forward(self, cache, blocks):
        if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
            out = cache[: blocks.size(0)]
        else:
            out = cache.index_select(0, blocks)
        if out.dtype == torch.float8_e4m3fn:
            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
        return out


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -84,6 +135,7 @@ def paged_attention(
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    batch_size, head_num, head_size = query.shape
    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
    output = ops.flat_pa(
        query=query.view(batch_size, 1, head_num * head_size),
        key_cache=kv_cache.key,

@@ -92,20 +144,53 @@
        block_mapping=hpu_attention_meta.block_mapping,
        block_bias=hpu_attention_meta.attn_bias,
        block_groups=hpu_attention_meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=Matmul(),
        matmul_av_op=Matmul(),
        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=fetch_from_cache,
        values_fetch_func=fetch_from_cache,
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
    )
    # Reshape the output tensor.
    return output.view(batch_size, head_num, head_size)


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]


def paged_attention_mla(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
    kv_lora_rank: int = 0,
):
    batch_size, head_num, head_size = query.shape
    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
    output = ops.flat_pa_mla(
        query=query,
        key_cache=kv_cache.key,
        value_cache=None,
        block_list=hpu_attention_meta.block_list,
        block_mapping=hpu_attention_meta.block_mapping,
        block_bias=hpu_attention_meta.attn_bias,
        block_groups=hpu_attention_meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=None,
        kv_lora_rank=kv_lora_rank,
    )
    # Reshape the output tensor.
    return output.view(batch_size, head_num, -1)


__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
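For readers unfamiliar with the FP8 path above, here is a small CPU-only sketch (not part of the commit) of the idea behind `FP8Matmul`: quantize both operands to `float8_e4m3fn` with per-tensor scales, then apply the scales around the matmul. The real code uses `torch.ops.hpu.cast_to_fp8_v2` and `fp8_gemm_v2`, whose exact scale conventions may differ; this emulation only shows the standard quantize/rescale round trip and assumes PyTorch >= 2.1 for the `float8_e4m3fn` dtype.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into the representable FP8 range, then round to float8_e4m3fn.
    return (x / scale).to(torch.float8_e4m3fn)


def emulated_fp8_matmul(x, w, x_scale, w_scale):
    # Quantize both operands, dequantize, and run a normal matmul.
    # fp8_gemm_v2 on HPU folds the inverse scales into the GEMM instead.
    qx, qw = quant_fp8(x, x_scale), quant_fp8(w, w_scale)
    return (qx.float() * x_scale) @ (qw.float() * w_scale)


x = torch.randn(4, 8)
w = torch.randn(8, 16)
x_scale = x.abs().max() / FP8_MAX
w_scale = w.abs().max() / FP8_MAX
out = emulated_fp8_matmul(x, w, x_scale, w_scale)
print(out.shape, (out - x @ w).abs().max())  # small quantization error
```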
@@ -5,7 +5,6 @@ import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
from vllm_hpu_extension import cache_ops


@dataclass

@@ -50,15 +49,17 @@ class KVCache:
    ):
        """Construct the key-value cache for a layer."""
        ## TODO FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        self.kv_cache = (
            torch.zeros(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
            torch.zeros(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),

@@ -101,24 +102,89 @@ class KVCache:
            key_cache,
            value_cache,
            slots,
            kv_scales.key_scale_cpu,
            kv_scales.value_scale_cpu,
            kv_scales.key_scale,
            kv_scales.value_scale,
        )


class KVCompressCache(KVCache):
    """
    Key-value cache for attention layers.
    """

    kv_cache: torch.Tensor

    def __init__(
        self,
        *,
        num_blocks: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""
        ## TODO FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        self.kv_cache = torch.zeros(
            (num_blocks * BLOCK_SIZE, 1, head_size),
            dtype=dtype,
            device=device,
        )

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache.dtype

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots."""
        ## TODO FP8 kv cache support
        if self.kv_cache.dtype == torch.float8_e4m3fn:
            key = torch.ops.hpu.cast_to_fp8_v2(
                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
            )[0]
        self.kv_cache.index_copy_(0, slots, key)


def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
):
    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
    if key_cache.dtype == torch.float8_e4m3fn:
        key = torch.ops.hpu.cast_to_fp8_v2(
            key, k_scale, False, False, torch.float8_e4m3fn
        )[0]
        value = torch.ops.hpu.cast_to_fp8_v2(
            value, v_scale, False, False, torch.float8_e4m3fn
        )[0]
    key_cache.index_copy_(0, slots, key)
    value_cache.index_copy_(0, slots, value)


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
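A short illustration (not part of the commit) of the flattened cache layout that `KVCache` and `paged_reshape_and_cache` now use: the cache is a `(num_blocks * BLOCK_SIZE, heads, head_size)` tensor, a slot is simply a row index, writes become a single `index_copy_`, and the block/offset of a slot falls out of integer division and modulo. Plain CPU tensors, no HPU ops or FP8 casting.

```python
import torch

BLOCK_SIZE = 4
num_blocks, num_heads, head_size = 3, 2, 8

key_cache = torch.zeros(num_blocks * BLOCK_SIZE, num_heads, head_size)

# Three new tokens land in slots 5, 6 and 9.
slots = torch.tensor([5, 6, 9])
key = torch.randn(len(slots), num_heads, head_size)

# One scatter write, exactly as in the new store()/paged_reshape_and_cache().
key_cache.index_copy_(0, slots, key)

block_idx = slots // BLOCK_SIZE    # tensor([1, 1, 2])
block_offset = slots % BLOCK_SIZE  # tensor([1, 2, 1])
print(block_idx, block_offset)
print(torch.allclose(key_cache[slots], key))
```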
@@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp

w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
if is_hpu_gaudi2():
    FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max


def pad_weight(weight, block_size):
    """Pads a matrix to make its dimensions multiples of block_size."""
    M, N = weight.shape[-2:]
    block_size_m, block_size_n = block_size
    pad_M = (block_size_m - M % block_size_m) % block_size_m
    pad_N = (block_size_n - N % block_size_n) % block_size_n

    if pad_M == 0 and pad_N == 0:
        return weight, M, N  # No padding needed
    padded_weight = torch.nn.functional.pad(
        weight, (0, pad_N, 0, pad_M), mode="constant", value=0
    )
    return padded_weight, M, N  # Return original dimensions for unpadding


def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
    """Removes padding from the matrix to restore its original shape."""
    if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
        return weight
    if keep_first_dim:
        return weight[:, :original_M, :original_N]
    else:
        return weight[:original_M, :original_N]


def pad_block_fp8_weight_naive(weight, weight_scale, block_size):

    assert len(block_size) == 2

    block_size_m, block_size_n = block_size
    weight_scale_m, weight_scale_n = weight_scale.shape[-2:]

    weight, orig_M, orig_N = pad_weight(weight, block_size)
    M, N = weight.shape[-2:]

    assert weight_scale_m == M // block_size_m
    assert weight_scale_n == N // block_size_n

    return weight, orig_M, orig_N


def dynamic_quant(data, single_scale=False):
    if single_scale:
        scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
    else:
        scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
        scale = scale.unsqueeze(-1)
    data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
        data, 1.0 / scale, False, False, torch.float8_e4m3fn
    )[0]
    return data_fp8, scale.float()


def dequant_block_fp8_weight_naive(
    weight,
    weight_scale,
    block_size,
    dtype=torch.bfloat16,
    original_M=None,
    original_N=None,
    do_unpad=False,
):
    if weight_scale is None:
        return weight
    assert len(block_size) == 2

    weight_shape_len = len(weight.shape)

    block_size_m, block_size_n = block_size

    # mul scale
    if weight_shape_len == 2:
        weight_scale_m, weight_scale_n = weight_scale.shape
        weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
        weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
        if is_hpu_gaudi2():
            fake_weight = weight.cpu().to(dtype).to(weight.device)
            dequant_weight = fake_weight * weight_scale.to(dtype)
        else:
            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
        dequant_weight = dequant_weight.view(
            weight_scale_m * block_size_m, weight_scale_n * block_size_n
        )
        keep_first_dim = False
    elif weight_shape_len == 3:
        fd, weight_scale_m, weight_scale_n = weight_scale.shape
        weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
        weight = weight.view(
            fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
        )
        if is_hpu_gaudi2():
            fake_weight = weight.cpu().to(dtype).to(weight.device)
            dequant_weight = fake_weight * weight_scale.to(dtype)
        else:
            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
        dequant_weight = dequant_weight.view(
            fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
        )
        keep_first_dim = True
    else:
        raise ValueError("Only support original weight shape is either 2 or 3")

    if do_unpad:
        dequant_weight = unpad_weight(
            dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
        )

    return dequant_weight


def apply_block_fp8_linear_hpu_dynamic(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    x_fp8, x_scale = dynamic_quant(input_2d)

    output = torch.ops.hpu.fp8_gemm_v2(
        x_fp8,
        False,
        weight,
        True,
        None,
        torch.bfloat16,
        x_scale,
        weight_scale,
        None,
        False,
    )
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)


def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:

@@ -42,7 +182,7 @@ def per_tensor_dequantize(
) -> torch.Tensor:
    device = tensor.device
    dtype = torch.bfloat16
    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
    if is_hpu_gaudi2():
        # dequant on cpu to avoid nan on gaudi2
        tensor = tensor.to("cpu")

@@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):

        return UnquantizedWeight(w)

    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            if self.weight_block_size is not None:
                scale = [
                    weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
                    for p in prefixes
                ]
                scale = torch.cat(scale, dim=dim)
                scale = scale.to(weights.device)
                return Fp8Weight(
                    weight=w,
                    weight_scale=scale,
                    activation_scale_ub=self.activation_scale_ub,
                    dtype=weights.dtype,
                    weight_block_size=self.weight_block_size,
                )

            scale = [
                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
                for p in prefixes
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
                for p in prefixes
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            logical_widths = [x[0] for x in shapes]
            w, scale = requantize_with_max_scale(
                w, scale.to(weights.device), logical_widths, weights.dtype
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch

@@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module):
        scale_upper_bound = kwargs.get("scale_upper_bound", None)
        weight_block_size = kwargs.get("weight_block_size", None)

        if weight_block_size is not None:
            weight, orig_M, orig_N = pad_block_fp8_weight_naive(
                weight, scale, weight_block_size
            )
            weight, scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    weight,
                    scale,
                    weight_block_size,
                    original_M=orig_M,
                    original_N=orig_N,
                    do_unpad=True,
                )
            )
            scale = scale.squeeze(-1)

        return cls(
            qweight=weight,
            scale=scale,

@@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.weight_block_size is not None:
            # https://arxiv.org/pdf/2412.19437
            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
            # channels).
            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
            output = w8a8_block_fp8_matmul(
                qinput,
                self.qweight,
                scale,
                self.scale,
                self.weight_block_size,
                output_dtype=input.dtype,
            return apply_block_fp8_linear_hpu_dynamic(
                input, self.qweight, self.scale, self.input_scale, self.bias
            )

        if self.bias is not None:
            output = output + self.bias
        return output.to(dtype=input.dtype)

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
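To make the quantization helpers above concrete, a CPU emulation (not part of the commit) of the `dynamic_quant()` round trip: one scale per row, `scale = amax / FP8_MAX`, cast to `float8_e4m3fn`, and dequantize by multiplying the scale back. The HPU op `cast_to_fp8_v2` is replaced by a plain dtype cast here; assumes PyTorch >= 2.1.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def dynamic_quant_cpu(data: torch.Tensor):
    # One scale per row of the last dimension, mirroring dynamic_quant() above,
    # but with a plain dtype cast instead of torch.ops.hpu.cast_to_fp8_v2.
    scale = (data.abs().max(dim=-1, keepdim=True).values + 1e-8) / FP8_MAX
    data_fp8 = (data / scale).to(torch.float8_e4m3fn)
    return data_fp8, scale.float()


x = torch.randn(4, 128)
x_fp8, x_scale = dynamic_quant_cpu(x)
x_back = x_fp8.float() * x_scale
print((x - x_back).abs().max())  # small round-trip error
```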
@@ -4,7 +4,12 @@ from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from text_generation_server.utils.weights import (
    Weight,
    Weights,
    WeightsLoader,
    DefaultWeightsLoader,
)


from .hpu import QuantLinear

@@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
        quant_method: str,
        quantize: str,
        sym: bool,
        modules_to_not_convert: List[str],
    ):
        self.bits = bits
        self.desc_act = desc_act

@@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym
        self.modules_to_not_convert = modules_to_not_convert

    def is_layer_skipped_quantization(
        self, prefix: str, modules_to_not_convert: List[str]
    ):
        return any(module_name in prefix for module_name in modules_to_not_convert)

    def get_weights(self, weights: Weights, prefix: str):
        self._get_gptq_params(weights)

@@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False

        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return DefaultWeightsLoader.get_weights(weights, prefix)

        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError:

@@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return DefaultWeightsLoader.get_weights_col_packed(
                weights, prefix, block_sizes
            )
        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes

@@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1

@@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader):
        if self.bits != 4:
            use_exllama = False

        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return DefaultWeightsLoader.get_weights_row(weights, prefix)

        if self.desc_act:
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False
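A tiny illustration (not part of the commit) of the `is_layer_skipped_quantization` check added above: any layer whose prefix contains one of the configured module names falls back to the default (unquantized) loader. The module names below are made-up examples.

```python
modules_to_not_convert = ["lm_head", "visual"]


def is_layer_skipped(prefix: str) -> bool:
    # Same substring test as is_layer_skipped_quantization() above.
    return any(name in prefix for name in modules_to_not_convert)


print(is_layer_skipped("model.layers.0.self_attn.q_proj"))  # False -> GPTQ path
print(is_layer_skipped("lm_head"))                          # True  -> default loader
```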
@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        from vllm_hpu_extension.kernels import rms_norm

        orig_shape = hidden_states.shape
        if residual is not None:
            residual += hidden_states.view(residual.shape)
        else:
            residual = hidden_states
        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
        if len(orig_shape) == 2:
            residual = residual.unsqueeze(0)
        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
        return x.view(orig_shape), residual.view(orig_shape)
            hidden_states += residual
        residual = hidden_states
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(self.weight.dtype), residual
@@ -2,6 +2,7 @@ from typing import Optional

import torch
import torch.nn as nn
import os

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (

@@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import (
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
    dynamic_quant,
    dequant_block_fp8_weight_naive,
)

try:
    from .unquantized import fused_moe
except Exception:
    fused_moe = None
from text_generation_server.layers.moe.fused_moe import select_experts
import habana_frameworks.torch as htorch


class FP8SparseMoELayer(nn.Module):

@@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.ep_rank = self.rank
        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"

        if self.use_ep:
            n_experts = (n_experts + self.world_size - 1) // self.world_size
            self.ep_offset = self.ep_rank * n_experts
        else:
            self.ep_offset = 0

        (
            self.gate_up_proj,

@@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (

@@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module):
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
                use_ep=self.use_ep,
                ep_offset=self.ep_offset,
            )
        )
        if self.weight_block_size is not None:
            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.gate_up_proj,
                    self.gate_up_proj_weight_scale,
                    self.weight_block_size,
                )
            )
            self.down_proj, self.down_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size
                )
            )
            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
                self.gate_up_proj_weight_scale.squeeze(-1),
                self.down_proj_weight_scale.squeeze(-1),
            )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=gating_output,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            top_k=self.topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.n_expert_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )
        total_num_experts = gating_output.size(-1)
        x_fp8, x_scale = dynamic_quant(x, single_scale=True)

        if self.use_ep:
            moe_n_slice = 1
            n_expert_slice = (
                total_num_experts + self.world_size - 1
            ) // self.world_size
        else:
            moe_n_slice = 1
            n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
        for i in range(moe_n_slice):
            min_expert = i * n_expert_slice
            max_expert = min((i + 1) * n_expert_slice, total_num_experts)
            w13_list_slice = [
                self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
            ]
            w2_list_slice = [
                self.down_proj[j, ...] for j in range(min_expert, max_expert)
            ]
            w13_weight_scale = [
                self.gate_up_proj_weight_scale[j, ...]
                for j in range(min_expert, max_expert)
            ]
            w2_weight_scale = [
                self.down_proj_weight_scale[j, ...]
                for j in range(min_expert, max_expert)
            ]

            current_hidden_states = torch.ops.hpu.mixture_of_experts(
                hidden_states=x_fp8,
                expert_routing_table=topk_ids.to(torch.int64),
                router_weights=topk_weights.to(x.dtype),
                w12=w13_list_slice,
                w3=w2_list_slice,
                d_scale_hidden_states=x_scale,
                d_scale_w12=w13_weight_scale,
                d_scale_w3=w2_weight_scale,
                permuted_weights=True,
                activation="silu",
                experts_min=min_expert + self.ep_offset,
                experts_max=max_expert + self.ep_offset - 1,
            )
            htorch.core.mark_step()
            if i == 0:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)
        return final_hidden_states


def _load_expert_weights(

@@ -98,13 +170,14 @@ def _load_expert_weights(
    n_experts: int,
    name: str,
    weights: Weights,
    ep_offset: int = 0,
) -> torch.Tensor:
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)
        weight = get_weight_fn(prefix, i + ep_offset, name, weights)

        assert isinstance(weight, Fp8Weight)

@@ -147,14 +220,26 @@ def _load_expert_multi_weights_col(
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
    def get_weight_fn_sharded(prefix, i, name, weights):
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        name=None,
        weights=weights,
        ep_offset=ep_offset if use_ep else 0,
    )


@@ -164,10 +249,20 @@ def _load_expert_weights_row(
    n_experts: int,
    name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
    def get_weight_fn_sharded(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        name=name,
        weights=weights,
        ep_offset=ep_offset if use_ep else 0,
    )
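For context on the expert-parallel bookkeeping above, an illustrative helper (not part of the commit) showing how experts are carved across ranks: each rank keeps `ceil(n_experts / world_size)` experts starting at `ep_offset = rank * experts_per_rank`, and the offset is added back when indexing weights and when passing `experts_min`/`experts_max` to the fused op. Clamping the last rank's range is a choice made for this sketch.

```python
def expert_range(n_experts: int, world_size: int, rank: int):
    # Same ceil-division split as FP8SparseMoELayer.__init__ above.
    experts_per_rank = (n_experts + world_size - 1) // world_size
    ep_offset = rank * experts_per_rank
    experts_min = ep_offset
    experts_max = min(ep_offset + experts_per_rank, n_experts) - 1
    return experts_per_rank, ep_offset, experts_min, experts_max


for rank in range(3):
    print(rank, expert_range(n_experts=8, world_size=3, rank=rank))
# rank 0 -> (3, 0, 0, 2), rank 1 -> (3, 3, 3, 5), rank 2 -> (3, 6, 6, 7)
```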
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Tuple, Optional

import torch

@@ -25,12 +25,36 @@ def grouped_topk(
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    gating_output = gating_output.float()
    if e_score_correction_bias is not None:
        e_score_correction_bias = e_score_correction_bias.float()

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]

@@ -41,13 +65,19 @@ def grouped_topk(
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def fused_topk(

@@ -63,3 +93,39 @@ def fused_topk(
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
):

    # DeekSeekv2 uses grouped_top_k
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
    else:
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    return topk_weights, topk_ids
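A compact, runnable illustration (not part of the commit) of the correction-bias routing added to `grouped_topk` above: experts are selected with the biased sigmoid scores, while the routing weights are gathered from the original scores and then renormalized. The expert-group selection step is omitted to keep the example small.

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts, topk = 2, 8, 2
gating_output = torch.randn(num_tokens, num_experts)
e_score_correction_bias = torch.randn(num_experts)

# Sigmoid scoring, as in the "sigmoid" branch above.
scores = gating_output.sigmoid()
original_scores = scores
biased = scores + e_score_correction_bias.unsqueeze(0)

# Select experts using the biased scores, weight them with the unbiased ones.
topk_ids = torch.topk(biased, k=topk, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

print(topk_ids)
print(topk_weights)
```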
@@ -4,7 +4,9 @@ import torch
import torch.nn as nn

from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F


class UnquantizedSparseMoELayer(nn.Module):

@@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
            weights=weights,
        )

        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
        for i in range(n_experts):
            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
            self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
            self.MoeOp.w2_list[i].set_weight(self.down_proj[i])

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.hpu_fused_moe(x, gating_output, self.topk)
        htorch.core.mark_step()
        routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.topk, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(x.dtype)

        final_hidden_states = self.MoeOp(
            hidden_states=x,
            expert_routing_table=selected_experts,
            router_weights=routing_weights,
            permuted_weights=True,
            activation="silu",
        )

        return final_hidden_states.view(-1, x.shape[1])


def _load_expert_multi_weights_col(
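As a reference for what the fused `VllmMixtureOfExpertsOp` call above computes, a naive dense sketch (not part of the commit) using the same softmax + top-k + renormalize routing; the expert shapes and the SwiGLU-style MLP here are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, n_experts, topk = 16, 32, 4, 2
x = torch.randn(5, hidden)
gating_output = torch.randn(5, n_experts)

# gate_up_proj: (E, 2*inter, hidden), down_proj: (E, hidden, inter)
gate_up_proj = torch.randn(n_experts, 2 * inter, hidden) * 0.02
down_proj = torch.randn(n_experts, hidden, inter) * 0.02

# Same routing as the new forward() above.
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(x)
for token in range(x.shape[0]):
    for w, e in zip(routing_weights[token], selected_experts[token]):
        gate_up = gate_up_proj[e] @ x[token]
        gate, up = gate_up.chunk(2)
        out[token] += w * (down_proj[e] @ (F.silu(gate) * up))
print(out.shape)
```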
@@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(
            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
        )
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

@@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
@@ -360,6 +360,7 @@ def get_model(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:

@@ -485,7 +486,12 @@ def get_model(

    model_type = config_dict["model_type"]

    kv_cache_dtype = dtype
    if kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        kv_cache_dtype = dtype

    if FLASH_ATTENTION:
        if model_type == DEEPSEEK_V2:

@@ -976,6 +982,7 @@ def get_model_with_lora_adapters(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],

@@ -989,6 +996,7 @@ def get_model_with_lora_adapters(
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )
@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
    apply_rotary_pos_emb,
)

import habana_frameworks.torch as htorch


class CohereRotary(PositionRotaryEmbedding):
    def forward(

@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,

@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
import habana_frameworks.torch as htorch


class DbrxAttentionConfig(PretrainedConfig):

@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,

@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
@@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch


class DeepseekV2Config(PretrainedConfig):

@@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,

@@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
@ -28,11 +28,12 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
Fp8Linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
paged_attention_mla,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||
@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.weights import Weights
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
|
||||
if isinstance(layer, Fp8Linear):
|
||||
eye = torch.eye(
|
||||
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
|
||||
)
|
||||
dequant_weights = layer(eye)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.value_head_size,
        )
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )
        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

    def _q_proj_and_k_up_proj(self, x):
        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
        q_nope, q_pe = (
            q_proj(x)
            .view(-1, self.num_heads, self.head_size)
            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        )

        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        return ql_nope.transpose(0, 1), q_pe

    def _v_up_proj_and_o_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
        return self.o_proj(x)
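
These two helpers implement the decode-time weight absorption for MLA: instead of up-projecting the cached latent back to full keys and values, W_UK is folded into the query side (so attention scores are taken directly against the latent cache) and W_UV is applied to the attention output just before o_proj. A shape-only sketch with invented sizes (N heads, B tokens, P = qk_nope_head_dim, L = kv_lora_rank, V = value_head_size), independent of the classes above:

import torch

N, B, P, L, V = 8, 4, 64, 512, 128
W_UK_T = torch.randn(N, P, L)       # per-head (P, L), as built in __init__ above
W_UV = torch.randn(N, L, V)         # per-head (L, V)

q_nope = torch.randn(B, N, P)
# Query side: (N, B, P) x (N, P, L) -> (N, B, L); scores can now be computed
# against the cached latent directly, with no per-token K up-projection.
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)    # (B, N, L)

attn_latent = torch.randn(B, N, L)  # attention output, still in latent space
# Value side: (N, B, L) x (N, L, V) -> (N, B, V), then flatten heads for o_proj.
out = torch.bmm(attn_latent.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)
print(ql_nope.shape, out.shape)     # torch.Size([4, 8, 512]) torch.Size([4, 1024])
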
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
):
|
||||
if self.q_lora_rank is None:
|
||||
query = self.q_proj(hidden_states)
|
||||
hidden_states_or_q_c = hidden_states
|
||||
else:
|
||||
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
_, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, key_pe = torch.split(
|
||||
@ -276,13 +323,18 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
)
|
||||
|
||||
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
|
||||
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
|
||||
)
|
||||
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
|
||||
query = q_proj(hidden_states_or_q_c)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
query_nope, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
else:
|
||||
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
|
||||
batch_size, heads, head_dim = query_pe.shape
|
||||
query_pe = (
|
||||
@ -297,33 +349,47 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
.reshape(batch_size, heads, head_dim)
|
||||
)
|
||||
self.rotary_emb(query_pe, key_pe, cos, sin)
|
||||
latent_vec_k = torch.concat(
|
||||
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
|
||||
)
|
||||
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
|
||||
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
key=latent_vec_k,
|
||||
value=None,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
kv = self.kv_b_proj(kv_c_normed).view(
|
||||
-1,
|
||||
self.num_key_value_heads,
|
||||
self.qk_nope_head_dim + self.value_head_size,
|
||||
)
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
seqlen=seqlen,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
# Decode
|
||||
query = torch.cat([query_nope, query_pe], dim=-1)
|
||||
attn_output = paged_attention_mla(
|
||||
query,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
kv_scales=self.kv_scales,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
)
|
||||
|
||||
# Remove padding.
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
attn_output = self._v_up_proj_and_o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -38,6 +38,7 @@ from text_generation_server.layers import (
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
|
@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_attention(config, prefix: str, weights):
|
||||
@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -26,7 +26,7 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
import habana_frameworks.torch as htorch
|
||||
from text_generation_server.layers.attention import (
|
||||
KVCache,
|
||||
get_kv_scales,
|
||||
@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
cross_attention_states,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class MistralConfig(PretrainedConfig):
|
||||
@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class MixtralConfig(PretrainedConfig):
|
||||
@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class GPTNeoXConfig(TransformersGPTNeoXConfig):
|
||||
@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
||||
|
||||
|
@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -18,7 +18,6 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
|
||||
)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.h):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
|
||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class Starcoder2Config(PretrainedConfig):
|
||||
@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
KVCache,
|
||||
KVCompressCache,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
trim_attn_metadata,
|
||||
@ -68,11 +69,14 @@ from text_generation_server.utils.import_utils import (
|
||||
synchronize,
|
||||
get_free_memory,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.prefill_chunking import (
|
||||
get_max_prefill_tokens,
|
||||
)
|
||||
import vllm_hpu_extension.environment as environment
|
||||
import habana_frameworks.torch as htorch
|
||||
import itertools
|
||||
from vllm_hpu_extension.bucketing import HPUBucketingContext
|
||||
from vllm_hpu_extension.bucketing.common import get_bucketing_context
|
||||
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -425,7 +429,9 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
|
||||
|
||||
@ -628,21 +634,25 @@ class FlashCausalLMBatch(Batch):
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||
cache_lengths_tensor = self.cache_lengths_tensor[indices]
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
)
|
||||
if self.adapter_meta is not None:
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
adapter_segments, adapter_segment_indices = find_segments(
|
||||
adapter_indices
|
||||
)
|
||||
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
)
|
||||
else:
|
||||
adapter_meta = None
|
||||
htorch.core.mark_step()
|
||||
return type(self)(
|
||||
batch_id=self.batch_id,
|
||||
@ -704,6 +714,7 @@ class FlashCausalLMBatch(Batch):
|
||||
max_length = 0
|
||||
max_input_length = 0
|
||||
max_current_length = 0
|
||||
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||
for b in batches:
|
||||
total_batch_size += len(b)
|
||||
max_blocks = max(max_blocks, b.max_blocks)
|
||||
@ -757,14 +768,15 @@ class FlashCausalLMBatch(Batch):
|
||||
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
|
||||
total_batch_size
|
||||
)
|
||||
total_indices_size = sum(
|
||||
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
||||
)
|
||||
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
|
||||
total_indices_size
|
||||
)
|
||||
adapter_segment_builder = SegmentConcatBuilder()
|
||||
adapter_set = set()
|
||||
if ADAPTER_TO_INDEX:
|
||||
total_indices_size = sum(
|
||||
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
||||
)
|
||||
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
|
||||
total_indices_size
|
||||
)
|
||||
adapter_segment_builder = SegmentConcatBuilder()
|
||||
adapter_set = set()
|
||||
|
||||
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
|
||||
total_batch_size
|
||||
@ -815,9 +827,7 @@ class FlashCausalLMBatch(Batch):
|
||||
start_index = cumulative_batch_size
|
||||
end_index = cumulative_batch_size + valid_bsize
|
||||
|
||||
index = torch.tensor(
|
||||
list(range(start_index, end_index)), device=batch.input_ids.device
|
||||
)
|
||||
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
|
||||
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
@ -841,7 +851,9 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
if not prefilling:
|
||||
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
|
||||
input_ids.index_copy_(
|
||||
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
|
||||
)
|
||||
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
|
||||
slot_indices.index_copy_(
|
||||
0, index, batch.slot_indices + cumulative_slots
|
||||
@ -852,20 +864,21 @@ class FlashCausalLMBatch(Batch):
|
||||
cache_lengths_tensor.index_copy_(
|
||||
0, index, batch.cache_lengths_tensor[:valid_bsize]
|
||||
)
|
||||
adapter_start_index = cumulative_adapter_indices_size
|
||||
adapter_end_index = (
|
||||
cumulative_adapter_indices_size
|
||||
+ batch.adapter_meta.adapter_indices.shape[0]
|
||||
)
|
||||
adapter_indices[adapter_start_index:adapter_end_index] = (
|
||||
batch.adapter_meta.adapter_indices
|
||||
)
|
||||
cumulative_adapter_indices_size = adapter_end_index
|
||||
adapter_set.update(batch.adapter_meta.adapter_set)
|
||||
adapter_segment_builder.concat(
|
||||
batch.adapter_meta.adapter_segments,
|
||||
batch.adapter_meta.segment_indices,
|
||||
)
|
||||
if ADAPTER_TO_INDEX:
|
||||
adapter_start_index = cumulative_adapter_indices_size
|
||||
adapter_end_index = (
|
||||
cumulative_adapter_indices_size
|
||||
+ batch.adapter_meta.adapter_indices.shape[0]
|
||||
)
|
||||
adapter_indices[adapter_start_index:adapter_end_index] = (
|
||||
batch.adapter_meta.adapter_indices
|
||||
)
|
||||
cumulative_adapter_indices_size = adapter_end_index
|
||||
adapter_set.update(batch.adapter_meta.adapter_set)
|
||||
adapter_segment_builder.concat(
|
||||
batch.adapter_meta.adapter_segments,
|
||||
batch.adapter_meta.segment_indices,
|
||||
)
|
||||
else:
|
||||
if isinstance(batch.input_ids, torch.Tensor):
|
||||
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
|
||||
@ -908,7 +921,7 @@ class FlashCausalLMBatch(Batch):
|
||||
else:
|
||||
speculative_ids = None
|
||||
|
||||
if adapter_segment_builder is not None:
|
||||
if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
|
||||
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
@ -955,7 +968,7 @@ class FlashCausalLMBatch(Batch):
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
adapter_meta=adapter_meta,
|
||||
adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
|
||||
hpu_attn_meta=None,
|
||||
next_token_logits=None,
|
||||
speculative_logits=None,
|
||||
@ -1031,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
|
||||
# need extra pad to match warmup seq
|
||||
extra_pad = max_padded_input_len - self.max_input_length
|
||||
extra_pad_bs = max_padded_bs - len(self)
|
||||
device = self.all_input_ids_tensor.device
|
||||
if isinstance(self.input_ids, list) and len(self) > 1:
|
||||
input_ids_padded_length = []
|
||||
input_ids = []
|
||||
@ -1041,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids.append(input_id)
|
||||
input_ids_padded_length.append(padded)
|
||||
input_ids = np.concatenate(input_ids, dtype=np.int64)
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
elif isinstance(self.input_ids, list):
|
||||
input_ids = self.input_ids[0]
|
||||
input_ids_padded_length.append(extra_pad)
|
||||
input_ids = [0] * extra_pad + input_ids
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
else:
|
||||
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
|
||||
input_ids_padded_length.extend([extra_pad] * len(self))
|
||||
@ -1239,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
|
||||
self.slot_indices = slot_indices
|
||||
|
||||
self.prefill_cu_outlens = prefill_cu_outlens
|
||||
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
|
||||
self.prefill_cache_indices = torch.zeros_like(
|
||||
self.input_ids, dtype=torch.bool, device="cpu"
|
||||
)
|
||||
self.prefill_cache_indices[prefill_cache_indices] = True
|
||||
|
||||
if all_prefill_logprobs:
|
||||
@ -1295,21 +1311,24 @@ class FlashCausalLMBatch(Batch):
|
||||
fsm_grammar_states,
|
||||
)
|
||||
|
||||
if adapter_set:
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
else:
|
||||
adapter_indices = torch.zeros_like(self.input_ids)
|
||||
adapter_segments = [0, len(adapter_indices)]
|
||||
adapter_segment_indices = [len(adapter_indices) - 1]
|
||||
if ADAPTER_TO_INDEX:
|
||||
if adapter_set:
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
|
||||
adapter_segments, adapter_segment_indices = find_segments(
|
||||
adapter_indices
|
||||
)
|
||||
else:
|
||||
adapter_indices = torch.zeros_like(self.input_ids)
|
||||
adapter_segments = [0, len(adapter_indices)]
|
||||
adapter_segment_indices = [len(adapter_indices) - 1]
|
||||
|
||||
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
||||
self.adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
)
|
||||
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
||||
self.adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
@ -1352,6 +1371,8 @@ class FlashCausalLM(Model):
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if world_size > 1:
|
||||
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
|
||||
|
||||
device = torch.device("hpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
@ -1439,15 +1460,18 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
self.bucketing_ctx = None
|
||||
htorch.core.hpu_set_env()
|
||||
if htorch.utils.internal.is_lazy():
|
||||
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||
environment.set_model_config(self.config)
|
||||
self.use_contiguous_pa = (
|
||||
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
|
||||
)
|
||||
self.limit_hpu_graphs = (
|
||||
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
|
||||
self.limit_hpu_graph = (
|
||||
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||
)
|
||||
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
|
||||
self.max_seq_len_to_capture = 8192
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
@ -1479,16 +1503,27 @@ class FlashCausalLM(Model):
|
||||
):
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
self.kv_cache = [
|
||||
KVCache(
|
||||
num_blocks=num_blocks,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
if self.config.model_type == "deepseek_v3":
|
||||
self.kv_cache = [
|
||||
KVCompressCache(
|
||||
num_blocks=num_blocks,
|
||||
head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
self.kv_cache = [
|
||||
KVCache(
|
||||
num_blocks=num_blocks,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def warmup(
|
||||
self,
|
||||
@ -1496,16 +1531,40 @@ class FlashCausalLM(Model):
|
||||
max_input_tokens: Optional[int],
|
||||
max_total_tokens: Optional[int],
|
||||
):
|
||||
if os.environ.get("MAX_BATCH_SIZE") is None:
|
||||
raise RuntimeError(
|
||||
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
|
||||
"using `--max-batch-size xxx`"
|
||||
)
|
||||
# The warmup batch is the biggest batch we could ever receive
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
|
||||
self.graphed_buckets = set()
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
|
||||
        if self.config.model_type == "deepseek_v3":
            cache_block_size = BLOCK_SIZE * (
                self.config.kv_lora_rank + self.config.qk_rope_head_dim
            )
        else:
            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        cache_block_size = cache_block_size * 2
        total_cache_size = self.num_layers * cache_block_size * dtype_size
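
Per block, the standard path stores full K and V for every KV head, while the DeepSeek MLA path stores one compressed latent of size kv_lora_rank + qk_rope_head_dim per token; the factor of 2 and the dtype size are applied uniformly, exactly as in the code above. A worked example with illustrative numbers (block size, head counts and ranks are made up for the arithmetic only):

BLOCK_SIZE, dtype_size, num_layers = 128, 2, 61               # bf16 -> 2 bytes

num_kv_heads, head_size = 8, 128
std_block = BLOCK_SIZE * num_kv_heads * head_size * 2 * dtype_size       # ~0.5 MiB per layer

kv_lora_rank, qk_rope_head_dim = 512, 64
mla_block = BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim) * 2 * dtype_size

print(num_layers * std_block)   # bytes per block across all layers, standard KV cache (~32.0e6)
print(num_layers * mla_block)   # bytes per block across all layers, MLA latent cache (~18.0e6)
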
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
|
||||
self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
|
||||
graph_reserved_mem = (
|
||||
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
|
||||
if htorch.utils.internal.is_lazy()
|
||||
else 0
|
||||
)
|
||||
mem_used_from_graph = int(
|
||||
(free_memory - self.mem_reserved) * graph_reserved_mem
|
||||
)
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
|
||||
)
|
||||
try:
|
||||
self.init_kv_cache(
|
||||
batch.num_blocks,
|
||||
@ -1520,15 +1579,6 @@ class FlashCausalLM(Model):
|
||||
|
||||
num_tokens = batch.to_pb().current_tokens
|
||||
synchronize(self.device)
|
||||
free_memory = get_free_memory(
|
||||
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
|
||||
)
|
||||
real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||
log_master(
|
||||
logger.debug,
|
||||
f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
|
||||
)
|
||||
|
||||
_, _batch, _ = self.generate_token([batch])
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
@ -1537,8 +1587,9 @@ class FlashCausalLM(Model):
|
||||
)
|
||||
|
||||
synchronize(self.device)
|
||||
free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
|
||||
kv_memory = free_memory
|
||||
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
|
||||
|
||||
kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
|
||||
num_blocks = (
|
||||
# Leave 5% for some wiggle room
|
||||
int(kv_memory // total_cache_size)
|
||||
@ -1555,7 +1606,6 @@ class FlashCausalLM(Model):
|
||||
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
|
||||
self.init_kv_cache(
|
||||
num_blocks,
|
||||
self.num_layers,
|
||||
@ -1564,56 +1614,177 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache_dtype,
|
||||
self.device,
|
||||
)
|
||||
|
||||
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
|
||||
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
|
||||
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
|
||||
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
|
||||
max_total_blocks = (
|
||||
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
|
||||
)
|
||||
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
|
||||
self.max_batch_prefill_tokens = get_max_prefill_tokens()
|
||||
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
|
||||
HPUBucketingContext = get_bucketing_context()
|
||||
# need to warm up one more step since blocks are allocated starting from 1
|
||||
block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
|
||||
max_total_tokens_aligned = math.ceil(
|
||||
max_total_tokens / BLOCK_SIZE
|
||||
) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
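
The budget is first rounded up to a whole number of blocks, then padded by roughly one bucket step spread across the batch, since decode block bucketing starts from 1. A worked example with assumed values (BLOCK_SIZE=128, block_step=128, max_num_seqs=64, max_total_tokens=4096):

import math

BLOCK_SIZE, block_step, max_num_seqs, max_total_tokens = 128, 128, 64, 4096
aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE + math.ceil(
    block_step * BLOCK_SIZE / max_num_seqs
)
print(aligned)  # 4096 + 256 = 4352 tokens budgeted per sequence
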
|
||||
model_max_length = self.tokenizer.model_max_length
|
||||
max_position_embeddings = getattr(
|
||||
self.config, "max_position_embeddings", model_max_length
|
||||
)
|
||||
|
||||
self.bucketing_ctx = HPUBucketingContext(
|
||||
max_num_seqs,
|
||||
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
|
||||
max_num_seqs, # self.max_num_prefill_seqs, #TODO
|
||||
BLOCK_SIZE,
|
||||
num_blocks * BLOCK_SIZE,
|
||||
max_num_seqs * max_total_tokens_aligned,
|
||||
False,
|
||||
min(model_max_length, max_position_embeddings),
|
||||
max_input_tokens,
|
||||
max_total_tokens_aligned,
|
||||
)
|
||||
self.bucketing_ctx.num_hpu_blocks = num_blocks
|
||||
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
|
||||
logger.info("skip warmup hpu graph, not recommmended")
|
||||
max_blocks = max(
|
||||
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
|
||||
)
|
||||
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
|
||||
synchronize(self.device)
|
||||
if self.skip_warmup:
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
self.bucketing_ctx.generate_decode_buckets(
|
||||
self.bucketing_ctx.num_hpu_blocks
|
||||
)
|
||||
log_master(
|
||||
logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
|
||||
)
|
||||
del _batch, batch
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
|
||||
self.warmup_hpu_graph(batch)
|
||||
del _batch, batch
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
|
||||
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
|
||||
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
|
||||
phase = "Prompt" if prefilling else "Decode"
|
||||
dim = "seq_len" if prefilling else "num_blocks"
|
||||
graphed_bucket = (batch_size, seq_len, prefilling)
|
||||
bypass = graphed_bucket not in self.graphed_buckets
|
||||
msg = (
|
||||
f"[Warmup][{phase}][{i+1}/{max_i}] "
|
||||
f"batch_size:{batch_size} "
|
||||
f"{dim}:{seq_len} "
|
||||
f"bypass:{bypass} "
|
||||
f"free_mem:{free_mem}"
|
||||
)
|
||||
log_master(logger.info, msg)
|
||||
|
||||
    def use_graphs(self, prefill, seq_len, batch_size):
        if self.limit_hpu_graph and prefill:
            return False

        if self.skip_warmup:
            return True

        return (batch_size, seq_len, prefill) in self.graphed_buckets
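
This gate is what the forward and warmup paths consult to decide whether a shape runs under a captured HPU graph or eagerly (by setting bypass_hpu_graphs=True). A standalone rehearsal of the same logic with a hypothetical bucket set; in the model the key is (batch_size, seq_len, True) for prefill and (batch_size, num_blocks, False) for decode:

graphed_buckets = {(8, 1024, True), (8, 64, False)}   # invented warmup result

def gate(prefill, seq_len, batch_size, limit_hpu_graph=False, skip_warmup=False):
    if limit_hpu_graph and prefill:
        return False
    if skip_warmup:
        return True
    return (batch_size, seq_len, prefill) in graphed_buckets

print(gate(True, 1024, 8))   # True  -> run under the captured graph
print(gate(False, 128, 8))   # False -> caller sets bypass_hpu_graphs=True
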
    def align_workers(self, value, op):
        if self.world_size <= 1:
            return value
        value_t = torch.tensor(value, device="cpu")
        torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
        return value_t.item()
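
Each rank measures memory on its own device, so the raw numbers can differ across workers; reducing them over the gloo CPU group makes every rank plan with the same figure. The warmup code below reduces free memory with MIN (the most constrained rank bounds everyone) and measured graph memory with MAX (the accounting stays conservative). A plain-Python illustration of what those two reductions yield, with invented per-rank numbers:

per_rank_free = [30.1e9, 29.4e9, 30.0e9]   # hypothetical bytes free on each rank
per_rank_used = [1.2e9, 1.5e9, 1.3e9]      # hypothetical bytes one graph capture consumed

graph_free_mem = min(per_rank_free)        # what align_workers(..., ReduceOp.MIN) agrees on
used_mem = max(per_rank_used)              # what align_workers(..., ReduceOp.MAX) agrees on
print(graph_free_mem, used_mem)
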
def warmup_hpu_graph(self, batch):
|
||||
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
|
||||
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
||||
graph_free_mem = free_mem - self.mem_reserved
|
||||
graph_free_mem = self.align_workers(
|
||||
graph_free_mem, torch.distributed.ReduceOp.MIN
|
||||
)
|
||||
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
|
||||
decode_available_memory = graph_free_mem - prompt_available_memory
|
||||
msg = (
|
||||
f"Using {format_bytes(graph_free_mem)}"
|
||||
f"/{format_bytes(free_mem)} "
|
||||
"of free device memory for HPUGraphs, "
|
||||
f"{format_bytes(prompt_available_memory)} for prompt and "
|
||||
f"{format_bytes(decode_available_memory)} for decode "
|
||||
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
|
||||
)
|
||||
log_master(logger.info, msg)
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
|
||||
def ordering_function_min_tokens(b):
|
||||
return (b[0] * b[1], b[1], b[0])
|
||||
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
|
||||
)
|
||||
total_batch_seq = 0.001
|
||||
total_mem = 0
|
||||
available_mem = prompt_available_memory
|
||||
for i, (batch_size, seq_len) in enumerate(buckets):
|
||||
if batch_size * seq_len > self.max_batch_prefill_tokens:
|
||||
continue
|
||||
# Graph memory usage is proportional to seq dimension in a batch
|
||||
batch_seq = batch_size * seq_len
|
||||
mem_estimate = batch_seq / total_batch_seq * total_mem
|
||||
graphed_bucket = (batch_size, seq_len, True)
|
||||
if not (
|
||||
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
|
||||
):
|
||||
if graphed_bucket not in self.graphed_buckets:
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
warmup_shape_count += 1
|
||||
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
synchronize(self.device)
|
||||
used_mem = self.align_workers(
|
||||
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
|
||||
)
|
||||
if graphed_bucket in self.graphed_buckets:
|
||||
available_mem -= used_mem
|
||||
total_mem += used_mem
|
||||
total_batch_seq += batch_seq
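
Before capturing a bucket, the loop estimates how much graph memory it would cost from what has been measured so far: capture cost is assumed to scale with batch_size * seq_len, so the estimate is this bucket's batch*seq times the running average cost per unit measured so far. A toy run of that estimator with invented measurements:

measured = [((1, 128), 0.20), ((2, 128), 0.41), ((4, 256), 1.70)]  # (bucket, GiB actually used)

total_batch_seq, total_mem = 0.001, 0.0
available_mem = 1.0                      # GiB budget for prompt graphs (made up)
for (batch_size, seq_len), used_mem in measured:
    batch_seq = batch_size * seq_len
    mem_estimate = batch_seq / total_batch_seq * total_mem
    capture = mem_estimate < available_mem
    if capture:
        available_mem -= used_mem        # only captured buckets consume the budget
    total_mem += used_mem
    total_batch_seq += batch_seq
    print((batch_size, seq_len), round(mem_estimate, 3), capture)
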
|
||||
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
||||
total_batch_seq = 0.001
|
||||
total_mem = 0
|
||||
available_mem = free_mem - self.mem_reserved
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
# Graph memory usage is proportional to seq dimension in a batch
|
||||
batch_seq = batch_size
|
||||
mem_estimate = batch_seq / total_batch_seq * total_mem
|
||||
graphed_bucket = (batch_size, block_num, False)
|
||||
if not mem_estimate >= available_mem:
|
||||
if graphed_bucket not in self.graphed_buckets:
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
warmup_shape_count += 1
|
||||
self.log_warmup(False, i, len(buckets), batch_size, block_num)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
used_mem = self.align_workers(
|
||||
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
if graphed_bucket in self.graphed_buckets:
|
||||
available_mem -= used_mem
|
||||
total_mem += used_mem
|
||||
total_batch_seq += batch_seq
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def warmup_prefill(
|
||||
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
|
||||
@ -1644,7 +1815,9 @@ class FlashCausalLM(Model):
|
||||
lm_head_indices = input_lengths - 1
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
True, prompt_len, batch_size
|
||||
)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
@ -1697,7 +1870,9 @@ class FlashCausalLM(Model):
|
||||
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = False
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
False, hpu_attention_meta.block_list.shape[0], batch_size
|
||||
)
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
@ -1780,11 +1955,11 @@ class FlashCausalLM(Model):
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
else:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[: slots.shape[0]] = slots
|
||||
slots = slots_pad
|
||||
seqlen = Seqlen(
|
||||
@ -1793,12 +1968,18 @@ class FlashCausalLM(Model):
|
||||
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = (
|
||||
batch.prefilling if self.limit_hpu_graphs else False
|
||||
batch_size = input_lengths.shape[0]
|
||||
prompt_len = (
|
||||
input_ids.shape[0] // batch_size
|
||||
if batch.prefilling
|
||||
else batch.hpu_attn_meta.block_list.shape[0]
|
||||
)
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
batch.prefilling, prompt_len, batch_size
|
||||
)
|
||||
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
input_ids=input_ids,
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||
kv_cache=kv_cache,
|
||||
@ -1837,9 +2018,7 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
speculative_ids,
|
||||
) = batch.next_token_chooser(
|
||||
_async_h2d_tensor_copy(
|
||||
batch.all_input_ids_tensor[:, : batch.max_current_length]
|
||||
),
|
||||
batch.all_input_ids_tensor[:, : batch.max_current_length],
|
||||
batch.next_token_logits,
|
||||
speculate,
|
||||
batch.speculative_ids,
|
||||
@ -1853,7 +2032,6 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
)
|
||||
if batch.valid_indices is not None:
|
||||
next_input_ids = next_input_ids.cpu()
|
||||
next_token_logprobs = next_token_logprobs.cpu()
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
|
||||
@ -1895,16 +2073,16 @@ class FlashCausalLM(Model):
|
||||
batch.position_ids = batch.position_ids[indices]
|
||||
|
||||
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
|
||||
batch.adapter_meta.adapter_indices = (
|
||||
batch.adapter_meta.adapter_indices[indices]
|
||||
)
|
||||
if batch.adapter_meta is not None:
|
||||
batch.adapter_meta.adapter_indices = (
|
||||
batch.adapter_meta.adapter_indices[indices]
|
||||
)
|
||||
# For each member of the batch
|
||||
# Cumulative length
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||
next_input_ids = next_input_ids.cpu()
|
||||
|
||||
if batch.speculative_logits is not None:
|
||||
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||
for i in range(len(batch)):
|
||||
batch.all_input_ids_tensor[
|
||||
i,
|
||||
@ -1913,9 +2091,23 @@ class FlashCausalLM(Model):
|
||||
+ batch.input_lengths[i]
|
||||
+ accepted_ids[i],
|
||||
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
if batch.position_ids.dim() == 2:
|
||||
# Qwen2_vl case:
|
||||
batch.position_ids += accepted_ids.unsqueeze(-1)
|
||||
else:
|
||||
batch.position_ids += accepted_ids
|
||||
batch.cache_lengths_tensor += (
|
||||
batch.input_lengths_tensor + accepted_ids - 1
|
||||
)
|
||||
batch.input_lengths_tensor = torch.ones_like(
|
||||
batch.input_lengths_tensor
|
||||
)
|
||||
batch.slot_indices += accepted_ids[: len(batch)]
|
||||
else:
|
||||
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
|
||||
index = index.to(batch.all_input_ids_tensor)
|
||||
index = index.to(batch.all_input_ids_tensor.device)
|
||||
batch_idx = torch.arange(
|
||||
0,
|
||||
batch.all_input_ids_tensor.shape[0],
|
||||
@ -1925,21 +2117,18 @@ class FlashCausalLM(Model):
|
||||
batch.all_input_ids_tensor.index_put_(
|
||||
(batch_idx, index.long()), next_input_ids
|
||||
)
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
batch.input_ids = next_input_ids
|
||||
batch.position_ids += 1
|
||||
batch.cache_lengths_tensor += batch.input_lengths_tensor
|
||||
batch.input_lengths_tensor = torch.ones_like(
|
||||
batch.input_lengths_tensor
|
||||
)
|
||||
batch.slot_indices += 1
|
||||
|
||||
batch.speculative_ids = speculative_ids
|
||||
if batch.position_ids.dim() == 2:
|
||||
# Qwen2_vl case:
|
||||
batch.position_ids += accepted_ids.unsqueeze(-1)
|
||||
else:
|
||||
batch.position_ids += accepted_ids
|
||||
batch.cache_lengths_tensor += (
|
||||
batch.input_lengths_tensor + accepted_ids - 1
|
||||
)
|
||||
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
||||
batch.slot_indices += accepted_ids[: len(batch)]
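
With speculation, each request can accept a different number of tokens in one step, so the bookkeeping above is driven by a cumulative sum over accepted_ids: the cumsum marks, for every request, its slice of the flattened next_input_ids, the last element of each slice becomes that request's next input token, and the position, cache-length, and slot counters all advance by the per-request accepted count. A toy example with invented values (three requests accepting 2, 1 and 3 tokens):

import torch

accepted_ids = torch.tensor([2, 1, 3])
next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])   # flattened per-request slices

cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
# cu_accepted_ids == [0, 2, 3, 6]; request i owns next_input_ids[cu[i]:cu[i+1]]

next_inputs = next_input_ids[cu_accepted_ids[1:] - 1]     # tensor([12, 21, 33])
position_ids = torch.tensor([5, 7, 9]) + accepted_ids     # each advances by its accepted count
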
|
||||
|
||||
# Does a HPU <-> CPU sync internally
|
||||
if prefill:
|
||||
if prefill and batch.adapter_meta is not None:
|
||||
# adjust segment lengths to account for all request lengths being 1 during decoding
|
||||
adapter_segments, _ = find_segments(
|
||||
batch.adapter_meta.adapter_indices
|
||||
@ -2030,30 +2219,33 @@ class FlashCausalLM(Model):
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
# Update adapter indices for speculative tokens (if present)
|
||||
adapter_meta = batch.adapter_meta
|
||||
if batch.speculative_ids is not None:
|
||||
B, speculative_length = batch.speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
adapter_indices = (
|
||||
adapter_meta.adapter_indices.unsqueeze(-1)
|
||||
.expand(B, new_length)
|
||||
.reshape(-1)
|
||||
)
|
||||
adapter_segments = adapter_meta.adapter_segments * new_length
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_meta.adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_meta.segment_indices,
|
||||
)
|
||||
if adapter_meta is not None:
|
||||
if batch.speculative_ids is not None:
|
||||
B, speculative_length = batch.speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
adapter_indices = (
|
||||
adapter_meta.adapter_indices.unsqueeze(-1)
|
||||
.expand(B, new_length)
|
||||
.reshape(-1)
|
||||
)
|
||||
adapter_segments = adapter_meta.adapter_segments * new_length
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_meta.adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_meta.segment_indices,
|
||||
)
|
||||
|
||||
# Assign pointers to adapter weights
|
||||
# TODO(travis): don't update this if indices haven't changed
|
||||
adapter_data = AdapterBatchData.from_meta(
|
||||
adapter_meta,
|
||||
self.layer_to_adapter_weights,
|
||||
prefill,
|
||||
batch.prefill_head_indices,
|
||||
)
|
||||
# Assign pointers to adapter weights
|
||||
# TODO(travis): don't update this if indices haven't changed
|
||||
adapter_data = AdapterBatchData.from_meta(
|
||||
adapter_meta,
|
||||
self.layer_to_adapter_weights,
|
||||
prefill,
|
||||
batch.prefill_head_indices,
|
||||
)
|
||||
else:
|
||||
adapter_data = None
|
||||
|
||||
out, speculative_logits = self.forward(batch, adapter_data)
|
||||
|
||||
|
@ -23,9 +23,11 @@ from text_generation_server.layers.attention import (
|
||||
_async_h2d_tensor_copy,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
import time
|
||||
from text_generation_server.utils.import_utils import (
|
||||
synchronize,
|
||||
)
|
||||
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -486,20 +488,63 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
|
||||
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
||||
graph_free_mem = free_mem - self.mem_reserved
|
||||
graph_free_mem = self.align_workers(
|
||||
graph_free_mem, torch.distributed.ReduceOp.MIN
|
||||
)
|
||||
decode_available_memory = graph_free_mem
|
||||
msg = (
|
||||
f"Using {format_bytes(graph_free_mem)}"
|
||||
f"/{format_bytes(free_mem)} "
|
||||
"of free device memory for HPUGraphs, "
|
||||
f"{format_bytes(decode_available_memory)} for decode "
|
||||
)
|
||||
log_master(logger.info, msg)
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
|
||||
# only warmup decode; for prefill, the image pixel size may change, which would make the warmup useless
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
total_batch_seq = 0.001
|
||||
total_mem = 0
|
||||
available_mem = decode_available_memory
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
# Graph memory usage is proportional to seq dimension in a batch
|
||||
batch_seq = batch_size
|
||||
mem_estimate = batch_seq / total_batch_seq * total_mem
|
||||
graphed_bucket = (batch_size, block_num, False)
|
||||
if not mem_estimate >= available_mem:
|
||||
if graphed_bucket not in self.graphed_buckets:
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
warmup_shape_count += 1
|
||||
self.log_warmup(False, i, len(buckets), batch_size, block_num)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
used_mem = self.align_workers(
|
||||
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
if graphed_bucket in self.graphed_buckets:
|
||||
|
||||
available_mem -= used_mem
|
||||
total_mem += used_mem
|
||||
total_batch_seq += batch_seq
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -572,14 +617,21 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = batch.prefilling
|
||||
|
||||
batch_size = input_lengths.shape[0]
|
||||
seqlen = (
|
||||
input_ids.shape[0] // batch_size
|
||||
if batch.prefilling
|
||||
else batch.hpu_attn_meta.block_list.shape[0]
|
||||
)
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
batch.prefilling, seqlen, batch_size
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
else:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[: slots.shape[0]] = slots
|
||||
slots = slots_pad
|
||||
|
||||
@ -587,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
input_ids=input_ids,
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||
kv_cache=kv_cache,
|
||||
|
@ -32,6 +32,9 @@ from text_generation_server.utils.import_utils import (
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
from text_generation_server.utils.log import log_master
|
||||
import time
|
||||
import os
|
||||
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -187,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
|
||||
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
||||
else:
|
||||
input_ids = batch.input_ids[0]
|
||||
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
||||
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
|
||||
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||
|
||||
@ -267,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
cross_attention_states, image_indices, input_lengths, 1, False
|
||||
)
|
||||
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
False, hpu_attention_meta.block_list.shape[0], batch_size
|
||||
)
|
||||
self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
@ -280,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
cross_attention_states=cross_attention_states,
|
||||
indices=_async_h2d_tensor_copy(indices),
|
||||
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def warmup_prefill(
|
||||
@ -325,7 +334,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
True, prompt_len, batch_size
|
||||
)
|
||||
self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
@ -343,26 +354,103 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
|
||||
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
|
||||
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
||||
graph_free_mem = free_mem - self.mem_reserved
|
||||
graph_free_mem = self.align_workers(
|
||||
graph_free_mem, torch.distributed.ReduceOp.MIN
|
||||
)
|
||||
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
|
||||
decode_available_memory = graph_free_mem - prompt_available_memory
|
||||
msg = (
|
||||
f"Using {format_bytes(graph_free_mem)}"
|
||||
f"/{format_bytes(free_mem)} "
|
||||
"of free device memory for HPUGraphs, "
|
||||
f"{format_bytes(prompt_available_memory)} for prompt and "
|
||||
f"{format_bytes(decode_available_memory)} for decode "
|
||||
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
|
||||
)
|
||||
log_master(logger.info, msg)
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
|
||||
def ordering_function_min_tokens(b):
|
||||
return (b[0] * b[1], b[1], b[0])
|
||||
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
|
||||
)
|
||||
graph_free_mem
|
||||
total_batch_seq = 0.001
|
||||
total_mem = 0
|
||||
available_mem = prompt_available_memory
|
||||
for i, (batch_size, seq_len) in enumerate(buckets):
|
||||
if batch_size * seq_len > self.max_batch_prefill_tokens:
|
||||
continue
|
||||
# Graph memory usage is proportional to seq dimension in a batch
|
||||
batch_seq = batch_size * seq_len
|
||||
mem_estimate = batch_seq / total_batch_seq * total_mem
|
||||
graphed_bucket = (batch_size, seq_len, True)
|
||||
if not (
|
||||
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
|
||||
):
|
||||
if graphed_bucket not in self.graphed_buckets:
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
warmup_shape_count += 1
|
||||
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
synchronize(self.device)
|
||||
used_mem = self.align_workers(
|
||||
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
|
||||
)
|
||||
if graphed_bucket in self.graphed_buckets:
|
||||
available_mem -= used_mem
|
||||
total_mem += used_mem
|
||||
total_batch_seq += batch_seq
|
||||
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
||||
total_batch_seq = 0.001
|
||||
total_mem = 0
|
||||
available_mem = free_mem - self.mem_reserved
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
# Graph memory usage is proportional to seq dimension in a batch
|
||||
batch_seq = batch_size
|
||||
mem_estimate = batch_seq / total_batch_seq * total_mem
|
||||
graphed_bucket = (batch_size, block_num, False)
|
||||
if not mem_estimate >= available_mem:
|
||||
if graphed_bucket not in self.graphed_buckets:
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
warmup_shape_count += 1
|
||||
self.log_warmup(False, i, len(buckets), batch_size, block_num)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
used_mem = self.align_workers(
|
||||
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
if graphed_bucket in self.graphed_buckets:
|
||||
available_mem -= used_mem
|
||||
total_mem += used_mem
|
||||
total_batch_seq += batch_seq
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
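Both warmup paths sort their buckets with the two ordering functions defined above: prompt buckets by smallest total token count first, decode buckets by largest batch size first. A standalone sketch with hypothetical buckets shows the resulting order:

```python
# Prompt buckets are (batch_size, seq_len); decode buckets are (batch_size, block_num).
prompt_buckets = [(4, 128), (1, 512), (2, 128)]
decode_buckets = [(2, 64), (8, 32), (8, 16)]


def ordering_function_min_tokens(b):
    # Smallest total token count first, then seq_len, then batch size.
    return (b[0] * b[1], b[1], b[0])


def ordering_function_max_bs(b):
    # Largest batch size first, then smallest block count.
    return (-b[0], b[1])


print(sorted(prompt_buckets, key=ordering_function_min_tokens))
# [(2, 128), (4, 128), (1, 512)]
print(sorted(decode_buckets, key=ordering_function_max_bs))
# [(8, 16), (8, 32), (2, 64)]
```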
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -438,15 +526,22 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = (
|
||||
batch.prefilling if self.limit_hpu_graphs else False
|
||||
batch_size = input_lengths.shape[0]
|
||||
seqlen = (
|
||||
input_ids.shape[0] // batch_size
|
||||
if batch.prefilling
|
||||
else batch.hpu_attn_meta.block_list.shape[0]
|
||||
)
|
||||
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
|
||||
batch.prefilling, seqlen, batch_size
|
||||
)
|
||||
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[batch.prefill_cache_indices] = slots
|
||||
slots = slots_pad
|
||||
else:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||
slots_pad[: slots.shape[0]] = slots
|
||||
slots = slots_pad
|
||||
orig_bs = len(batch)
|
||||
@ -475,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
input_ids=input_ids,
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||
kv_cache=kv_cache,
|
||||
|
@ -206,6 +206,7 @@ def serve(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
max_input_tokens: int,
|
||||
@ -218,6 +219,7 @@ def serve(
|
||||
quantize: Optional[str] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if not is_driver_compatible():
|
||||
@ -261,6 +263,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
data_type,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
adapter_to_index,
|
||||
@ -308,6 +311,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
@ -31,6 +31,7 @@ def main(args):
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
uds_path=args.uds_path,
|
||||
max_input_tokens=args.max_input_tokens,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,18 +1,9 @@
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_hpu_free_memory(device, memory_fraction):
|
||||
from habana_frameworks.torch.hpu import memory_stats
|
||||
|
||||
device_id = device.index
|
||||
mem_stats = memory_stats(device_id)
|
||||
logger.info(f"mem_stats: {mem_stats}")
|
||||
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
|
||||
free_memory = max(
|
||||
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
|
||||
)
|
||||
return free_memory
|
||||
free_hpu_memory, _ = torch.hpu.mem_get_info()
|
||||
return free_hpu_memory
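The simplified `get_hpu_free_memory` above now reads the free-memory figure straight from `torch.hpu.mem_get_info()`. Assuming that call mirrors `torch.cuda.mem_get_info()` and returns a `(free_bytes, total_bytes)` tuple, a caller that still wants to honour a memory fraction could look like this sketch:

```python
import torch


def free_memory_with_fraction(memory_fraction: float) -> int:
    # Assumption: mem_get_info() returns (free_bytes, total_bytes) like its CUDA counterpart.
    free_bytes, total_bytes = torch.hpu.mem_get_info()
    # Keep (1 - memory_fraction) of the card reserved, mirroring the old behaviour.
    reserved = int((1 - memory_fraction) * total_bytes)
    return max(0, free_bytes - reserved)
```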
|
||||
|
||||
|
||||
def synchronize_hpu(device):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.utils.weights import (
|
||||
@ -18,6 +18,8 @@ class _QuantizerConfig:
|
||||
groupsize: int
|
||||
quant_method: str
|
||||
sym: bool
|
||||
weight_block_size: Optional[List[int]]
|
||||
modules_to_not_convert: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
|
||||
activation_scale_ub: float
|
||||
|
||||
|
||||
# We should probably do this with Pytantic JSON deserialization,
|
||||
def _get_config_json(model_id: str, revision: Optional[str], filename: str):
|
||||
if os.path.exists(
|
||||
os.path.join(
|
||||
model_id,
|
||||
)
|
||||
):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# We should probably do this with Pydantic JSON deserialization,
|
||||
# but for now we'll stay close to the old _set_gptq_params.
|
||||
def _get_quantizer_config(model_id, revision):
|
||||
bits = 4
|
||||
@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
|
||||
checkpoint_format = None
|
||||
sym = False
|
||||
desc_act = False
|
||||
weight_block_size = None
|
||||
modules_to_not_convert = []
|
||||
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
# FP8 config
|
||||
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
|
||||
return _FP8QuantizerConfig(
|
||||
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
|
||||
)
|
||||
weight_block_size = data["quantization_config"].get("weight_block_size", None)
|
||||
|
||||
if "zero_point" in data["quantization_config"]:
|
||||
sym = not data["quantization_config"]["zero_point"]
|
||||
@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
quant_method = data["quantization_config"]["quant_method"]
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||
desc_act = data["quantization_config"]["desc_act"]
|
||||
desc_act = data["quantization_config"].get("desc_act", False)
|
||||
modules_to_not_convert = data["quantization_config"].get(
|
||||
"modules_to_not_convert", []
|
||||
)
|
||||
if modules_to_not_convert is None:
|
||||
modules_to_not_convert = []
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
bits = data["bits"]
|
||||
groupsize = data["group_size"]
|
||||
|
||||
@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
bits = data["w_bit"]
|
||||
groupsize = data["q_group_size"]
|
||||
desc_act = data["desc_act"]
|
||||
@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
|
||||
checkpoint_format=checkpoint_format,
|
||||
sym=sym,
|
||||
desc_act=desc_act,
|
||||
weight_block_size=weight_block_size,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
)
|
||||
|
||||
|
||||
@ -134,6 +139,7 @@ def get_loader(
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
modules_to_not_convert=quantizer_config.modules_to_not_convert,
|
||||
)
|
||||
elif quantize == "fp8" or quantize is None:
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
@ -141,9 +147,14 @@ def get_loader(
|
||||
# Since the default for the quantize config is _QuantizerConfig,
|
||||
# we need to add this check to not get an attribute error
|
||||
activation_scale_ub = None
|
||||
weight_block_size = quantizer_config.weight_block_size
|
||||
if isinstance(quantizer_config, _FP8QuantizerConfig):
|
||||
activation_scale_ub = quantizer_config.activation_scale_ub
|
||||
|
||||
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
|
||||
return HybridFP8UnquantLoader(
|
||||
activation_scale_ub,
|
||||
to_fp8=quantize == "fp8",
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization method: {quantize}")
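With this change the FP8 branch also forwards the checkpoint's `weight_block_size` into the loader. A hypothetical construction, assuming a block-quantized FP8 checkpoint that declares 128x128 weight blocks and no activation scale upper bound:

```python
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

# Hypothetical values for a block-quantized FP8 checkpoint.
loader = HybridFP8UnquantLoader(
    None,                          # activation_scale_ub: not set in this checkpoint
    to_fp8=True,                   # quantize remaining weights to FP8 at load time
    weight_block_size=[128, 128],  # assumed block shape from the quantization_config
)
```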
|
||||
|
@ -62,6 +62,14 @@ class WeightsLoader(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
"""
|
||||
Get the weights at the given prefixes, column-split them for tensor
|
||||
parallelism, and then concatenate the weights along the given dimension.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
weights.get_sharded(f"{prefix}.weight", dim=1),
|
||||
)
|
||||
|
||||
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
|
||||
return self.weight_class(torch.cat(w, dim=dim))
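For illustration, the default implementation above simply loads one tensor per prefix and concatenates them along `dim`; a standalone sketch with fake tensors (no sharding and no weight-class wrapper):

```python
import torch

# Standalone illustration of DefaultWeightsLoader.get_multi_weights:
# load one tensor per prefix and concatenate them along `dim`.
fake_store = {
    "q_proj.weight": torch.randn(8, 16),
    "k_proj.weight": torch.randn(8, 16),
    "v_proj.weight": torch.randn(8, 16),
}


def get_multi_weights(prefixes, dim):
    tensors = [fake_store[f"{p}.weight"] for p in prefixes]
    return torch.cat(tensors, dim=dim)


qkv = get_multi_weights(["q_proj", "k_proj", "v_proj"], dim=0)
print(qkv.shape)  # torch.Size([24, 16])
```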
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
@ -393,6 +405,9 @@ class Weights:
|
||||
def get_weights_row(self, prefix: str):
|
||||
return self.weights_loader.get_weights_row(self, prefix)
|
||||
|
||||
def get_multi_weights(self, prefixes: List[str], dim: int):
|
||||
return self.weights_loader.get_multi_weights(self, prefixes, dim)
|
||||
|
||||
@contextmanager
|
||||
def use_loader(self, weights_loader: WeightsLoader):
|
||||
"""
|
||||
|
@ -8,6 +8,7 @@ use std::cmp::max;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::usage_stats::Env;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
@ -185,6 +186,9 @@ struct State {
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
|
||||
/// Indicates whether this is an HPU device; HPU devices need padding to generate the first token.
|
||||
is_hpu_device: bool,
|
||||
}
|
||||
|
||||
impl State {
|
||||
@ -214,6 +218,7 @@ impl State {
|
||||
speculate,
|
||||
support_chunking,
|
||||
block_allocator,
|
||||
is_hpu_device: Env::new().is_hpu_device(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -368,6 +373,21 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_hpu_device {
|
||||
// HPU needs padding for the prefill
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
let actual_prefill_tokens_for_hpu =
|
||||
(batch.len() + 1) as u32 * max_input_length;
|
||||
|
||||
if actual_prefill_tokens_for_hpu > prefill_token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
}
|
||||
|
||||
prefill_tokens += postfix_len;
|
||||
|
||||
Some(block_allocation)
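The HPU-specific budget check above accounts for the whole prefill batch being padded to the longest request, so the effective token count is `(batch_len + 1) * max_input_length` rather than the sum of actual lengths. A small arithmetic sketch with hypothetical requests:

```python
# Hypothetical request lengths already queued plus one candidate entry.
queued_lengths = [120, 300, 64]
candidate_length = 512
prefill_token_budget = 4096

max_input_length = max(queued_lengths + [candidate_length])
# Every sequence is padded to max_input_length on HPU, so the padded cost is:
actual_prefill_tokens_for_hpu = (len(queued_lengths) + 1) * max_input_length
print(actual_prefill_tokens_for_hpu)                          # 2048
print(actual_prefill_tokens_for_hpu <= prefill_token_budget)  # True: the entry fits
```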
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "3.3.0-dev0"
|
||||
"version": "3.3.1-dev0"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN
|
||||
|
||||
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN
|
||||
|
||||
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
|
||||
--model-id $model
|
||||
<text-generation-inference-launcher-arguments>
|
||||
```
|
||||
@ -115,7 +115,7 @@ docker run -p 8080:80 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
@ -141,7 +141,7 @@ docker run -p 8080:80 \
|
||||
-v $volume:/data \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
--max-total-tokens 8192 --max-batch-size 4
|
||||
@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-e PROF_PATH=/tmp/hpu_profile \
|
||||
-e PROF_RANKS=0 \
|
||||
-e PROF_RECORD_SHAPES=True \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -31,7 +31,7 @@ deployment instructions in the model card:
|
||||
The service is launched simply by running the text-generation-inference container with two sets of parameters:
|
||||
|
||||
```
|
||||
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.0-neuron <service_parameters>
|
||||
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.1-neuron <service_parameters>
|
||||
```
|
||||
|
||||
- system parameters are used to map ports, volumes and devices between the host and the service,
|
||||
|
@ -19,6 +19,6 @@ docker run --gpus all \
|
||||
--shm-size 1g \
|
||||
-e HF_TOKEN=$token \
|
||||
-p 8080:80 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
|
||||
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes
|
||||
```
|
||||
|
||||
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
|
||||
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
|
||||
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes-nf4
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4
|
||||
```
|
||||
|
||||
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
|
||||
@ -48,7 +48,7 @@ $$\hat{W}_{l}^{*} = \underset{\hat{W}_{l}}{\mathrm{argmin}} \lVert W_{l}X-\hat{W}_{l}X \rVert^{2}_{2}$$
|
||||
TGI allows you to either run an already GPTQ-quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using the quantization script. You can run a quantized model by simply passing `--quantize` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize gptq
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq
|
||||
```
|
||||
|
||||
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
|
||||
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-rocm \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-intel-xpu \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0-intel-cpu \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
|
||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```bash
|
||||
docker run ghcr.io/huggingface/text-generation-inference:3.3.0 --help
|
||||
docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help
|
||||
```
|
||||
|
||||
</Tip>
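Once the container from the snippets above is running, the `/generate` route shown in the `curl` example can equally be called from Python; a minimal sketch using `requests` against the default `127.0.0.1:8080` address:

```python
import requests

# Assumes the TGI container above is listening on 127.0.0.1:8080.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    headers={"Content-Type": "application/json"},
)
print(response.json()["generated_text"])
```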
|
||||
|
@ -163,7 +163,7 @@ hub = {
|
||||
|
||||
# create Hugging Face Model Class
|
||||
huggingface_model = HuggingFaceModel(
|
||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.0"),
|
||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"),
|
||||
env=hub,
|
||||
role=role,
|
||||
)
|
||||
|
59
flake.lock
59
flake.lock
@ -102,7 +102,7 @@
|
||||
"flake-parts": "flake-parts_3",
|
||||
"nix-test-runner": "nix-test-runner_3",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"hf-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"pre-commit-hooks": "pre-commit-hooks_3"
|
||||
@ -579,6 +579,26 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"hf-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat_4",
|
||||
"flake-utils": "flake-utils_7",
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1747919133,
|
||||
"narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-filter": {
|
||||
"locked": {
|
||||
"lastModified": 1731533336,
|
||||
@ -718,16 +738,16 @@
|
||||
},
|
||||
"nixpkgs_6": {
|
||||
"locked": {
|
||||
"lastModified": 1737453259,
|
||||
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
|
||||
"lastModified": 1747820358,
|
||||
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
||||
"owner": "danieldk",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
|
||||
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "outlines-v0.1.4-tgi",
|
||||
"ref": "cudatoolkit-12.9-kernel-builder",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
@ -836,19 +856,19 @@
|
||||
"inputs": {
|
||||
"crate2nix": "crate2nix",
|
||||
"flake-utils": "flake-utils_6",
|
||||
"hf-nix": "hf-nix",
|
||||
"nix-filter": "nix-filter",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"hf-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"rust-overlay": "rust-overlay",
|
||||
"tgi-nix": "tgi-nix"
|
||||
"rust-overlay": "rust-overlay"
|
||||
}
|
||||
},
|
||||
"rust-overlay": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"hf-nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
@ -970,27 +990,6 @@
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"tgi-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat_4",
|
||||
"flake-utils": "flake-utils_7",
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1746795305,
|
||||
"narHash": "sha256-4fpUT4j4w0NDKF22KvG7iGmwQTBPM5SrPEqt+N3fqF0=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "359cd25f31f0f2ad2cadfbf4e180780a7a06e3c5",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "torch-2.7",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
14
flake.nix
14
flake.nix
@ -2,15 +2,15 @@
|
||||
inputs = {
|
||||
crate2nix = {
|
||||
url = "github:nix-community/crate2nix";
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
inputs.nixpkgs.follows = "hf-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
hf-nix.url = "github:huggingface/hf-nix";
|
||||
nixpkgs.follows = "hf-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
url = "github:oxalica/rust-overlay";
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
inputs.nixpkgs.follows = "hf-nix/nixpkgs";
|
||||
};
|
||||
};
|
||||
outputs =
|
||||
@ -21,7 +21,7 @@
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
rust-overlay,
|
||||
tgi-nix,
|
||||
hf-nix,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
@ -33,10 +33,10 @@
|
||||
};
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
inherit (tgi-nix.lib) config;
|
||||
inherit (hf-nix.lib) config;
|
||||
overlays = [
|
||||
rust-overlay.overlays.default
|
||||
tgi-nix.overlays.default
|
||||
hf-nix.overlays.default
|
||||
(import nix/overlay.nix)
|
||||
];
|
||||
};
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 42,
|
||||
"prompt_tokens": 277,
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 62,
|
||||
"prompt_tokens": 277,
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 67,
|
||||
"prompt_tokens": 277,
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 72,
|
||||
"prompt_tokens": 275,
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 80,
|
||||
"prompt_tokens": 279,
|
||||
|
@ -14,7 +14,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 35,
|
||||
"prompt_tokens": 32,
|
||||
|
@ -14,7 +14,7 @@
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 44,
|
||||
"prompt_tokens": 37,
|
||||
|
@ -18,7 +18,7 @@
|
||||
"id": "",
|
||||
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 45,
|
||||
@ -44,7 +44,7 @@
|
||||
"id": "",
|
||||
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 45,
|
||||
|
@ -17,7 +17,7 @@
|
||||
"id": "",
|
||||
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.0-dev0-native",
|
||||
"system_fingerprint": "3.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 45,
|
||||
|
@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option<usize> {
|
||||
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
|
||||
Ok(devices) => devices,
|
||||
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
|
||||
Ok(devices) => devices,
|
||||
Ok(devices) => {
|
||||
if devices.trim() == "all" {
|
||||
// Count all available GPUs via nvidia-smi
|
||||
let output = Command::new("nvidia-smi")
|
||||
.args(["--query-gpu=uuid", "--format=csv,noheader"])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
String::from_utf8_lossy(&output.stdout)
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.count()
|
||||
.to_string()
|
||||
} else {
|
||||
devices
|
||||
}
|
||||
}
|
||||
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
|
||||
},
|
||||
};
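The launcher change above counts GPUs by querying `nvidia-smi` when `NVIDIA_VISIBLE_DEVICES` is set to `all`; the same probing can be sketched in Python for illustration (assumes `nvidia-smi` is on PATH):

```python
import os
import subprocess


def num_visible_gpus() -> int:
    devices = os.environ.get("CUDA_VISIBLE_DEVICES") or os.environ.get(
        "NVIDIA_VISIBLE_DEVICES", ""
    )
    if devices.strip() == "all":
        # Mirror the launcher: one UUID line per GPU reported by nvidia-smi.
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=uuid", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
        return sum(1 for line in out.splitlines() if line.strip())
    # Otherwise the variable is a comma-separated list of device ids.
    return len([d for d in devices.split(",") if d.strip()])
```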
|
||||
|
@ -1,6 +1,7 @@
|
||||
{
|
||||
buildPythonPackage,
|
||||
poetry-core,
|
||||
aiohttp,
|
||||
huggingface-hub,
|
||||
pydantic,
|
||||
}:
|
||||
@ -15,6 +16,7 @@ buildPythonPackage {
|
||||
build-system = [ poetry-core ];
|
||||
|
||||
dependencies = [
|
||||
aiohttp
|
||||
huggingface-hub
|
||||
pydantic
|
||||
];
|
||||
|
@ -13,26 +13,26 @@ final: prev: {
|
||||
(
|
||||
python-self: python-super: with python-self; {
|
||||
# Python package override example:
|
||||
transformers = python-super.transformers.overrideAttrs (
|
||||
_: _: {
|
||||
src = final.fetchFromGitHub {
|
||||
owner = "huggingface";
|
||||
repo = "transformers";
|
||||
rev = "v4.51.0";
|
||||
hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8=";
|
||||
};
|
||||
}
|
||||
);
|
||||
huggingface-hub = python-super.huggingface-hub.overrideAttrs (
|
||||
_: _: {
|
||||
src = final.fetchFromGitHub {
|
||||
owner = "huggingface";
|
||||
repo = "huggingface_hub";
|
||||
rev = "v0.30.0";
|
||||
hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o=";
|
||||
};
|
||||
}
|
||||
);
|
||||
#transformers = python-super.transformers.overrideAttrs (
|
||||
# _: _: {
|
||||
# src = final.fetchFromGitHub {
|
||||
# owner = "huggingface";
|
||||
# repo = "transformers";
|
||||
# rev = "v4.51.0";
|
||||
# hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8=";
|
||||
# };
|
||||
# }
|
||||
#);
|
||||
#huggingface-hub = python-super.huggingface-hub.overrideAttrs (
|
||||
# _: _: {
|
||||
# src = final.fetchFromGitHub {
|
||||
# owner = "huggingface";
|
||||
# repo = "huggingface_hub";
|
||||
# rev = "v0.30.0";
|
||||
# hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o=";
|
||||
# };
|
||||
# }
|
||||
#);
|
||||
}
|
||||
)
|
||||
];
|
||||
|
@ -31,7 +31,7 @@
|
||||
peft,
|
||||
pillow,
|
||||
prometheus-client,
|
||||
punica-kernels,
|
||||
punica-sgmv,
|
||||
py-cpuinfo,
|
||||
pydantic,
|
||||
quantization,
|
||||
@ -107,7 +107,7 @@ buildPythonPackage {
|
||||
peft
|
||||
pillow
|
||||
prometheus-client
|
||||
punica-kernels
|
||||
punica-sgmv
|
||||
py-cpuinfo
|
||||
pydantic
|
||||
quantization
|
||||
|
@ -3,7 +3,6 @@ include Makefile-flash-att-v2
|
||||
include Makefile-vllm
|
||||
include Makefile-awq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
include Makefile-exllamav2
|
||||
include Makefile-flashinfer
|
||||
|
||||
|
@ -1,12 +0,0 @@
|
||||
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
|
||||
|
||||
build-lorax-punica:
|
||||
if [ ! -d 'lorax-punica' ]; then \
|
||||
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
|
||||
fi
|
||||
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
|
||||
cd lorax-punica && git submodule update --init --recursive
|
||||
cd lorax-punica/server/punica_kernels && python setup.py build
|
||||
|
||||
install-lorax-punica: build-lorax-punica
|
||||
cd lorax-punica/server/punica_kernels && python setup.py install
|
@ -163,6 +163,64 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/punica-sgmv",
|
||||
"sha": "9ae1b469cb39c33df9ddd61657c6359acc423714",
|
||||
"variants": {
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-aarch64-linux": {
|
||||
"hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-aarch64-linux": {
|
||||
"hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-x86_64-linux": {
|
||||
"hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/quantization",
|
||||
"sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",
|
||||
|
@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.kernels.dependencies]
|
||||
"kernels-community/paged-attention" = ">=0.0.2"
|
||||
"kernels-community/moe" = ">=0.1.1"
|
||||
"kernels-community/punica-sgmv" = ">=0.0.1"
|
||||
"kernels-community/quantization" = ">=0.0.3"
|
||||
"kernels-community/quantization-eetq" = ">=0.0.1"
|
||||
"kernels-community/rotary" = ">=0.0.1"
|
||||
|
@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchMetadata,
|
||||
AdapterWeights,
|
||||
BatchAdapterWeights,
|
||||
)
|
||||
from text_generation_server.utils.sgmv import (
|
||||
BGMV_MAX_RANK,
|
||||
MAX_RANK_CUSTOM,
|
||||
get_tmp_tensors,
|
||||
orient_for_rank,
|
||||
pad_rank,
|
||||
use_cutlass_shrink,
|
||||
has_sgmv,
|
||||
)
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
punica_sgmv = load_kernel(
|
||||
module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
|
||||
)
|
||||
else:
|
||||
punica_sgmv = None
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights):
|
||||
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
|
||||
|
||||
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||
self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
|
||||
self._is_transposed = False
|
||||
|
||||
# [num_layers, hidden_size, r]
|
||||
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||
weights_a = [
|
||||
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
|
||||
]
|
||||
self._weights_a = torch.stack(weights_a)
|
||||
|
||||
# [num_layers, r, hidden_size]
|
||||
@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights):
|
||||
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||
|
||||
# pad lora ranks to be compatible with sgmv
|
||||
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
|
||||
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
|
||||
lora_a_list = [
|
||||
punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
|
||||
]
|
||||
lora_b_list = [
|
||||
punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
|
||||
]
|
||||
|
||||
if lora_a_list:
|
||||
# update rank if it was padded
|
||||
@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights):
|
||||
|
||||
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||
return all(
|
||||
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||
rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
|
||||
for rank_data in self.rank_data.values()
|
||||
)
|
||||
|
||||
@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights):
|
||||
)
|
||||
|
||||
use_sgmv = False
|
||||
if prefill or max_rank > BGMV_MAX_RANK:
|
||||
if has_sgmv():
|
||||
if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
|
||||
if punica_sgmv is not None:
|
||||
use_sgmv = True
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights):
|
||||
|
||||
if use_sgmv:
|
||||
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||
tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
|
||||
lora_a_ptr_indices.size(0), rank, device
|
||||
)
|
||||
segment_starts = meta.adapter_segments[indices]
|
||||
|
@ -5,14 +5,16 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
add_lora_b_bgmv,
|
||||
has_sgmv,
|
||||
lora_a_sgmv_cutlass,
|
||||
lora_b_sgmv_cutlass,
|
||||
orient_for_rank,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
punica_sgmv = load_kernel(
|
||||
module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
|
||||
)
|
||||
else:
|
||||
punica_sgmv = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
@ -41,7 +43,11 @@ class LoraLinear(nn.Module):
|
||||
return result
|
||||
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
if (
|
||||
punica_sgmv is not None
|
||||
and data is not None
|
||||
and data.can_vectorize(self.process_group)
|
||||
):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
@ -68,7 +74,7 @@ class LoraLinear(nn.Module):
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = lora_a_sgmv_cutlass(
|
||||
v = punica_sgmv.lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
@ -81,7 +87,7 @@ class LoraLinear(nn.Module):
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
punica_sgmv.lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
@ -96,7 +102,7 @@ class LoraLinear(nn.Module):
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
add_lora_a_bgmv(
|
||||
punica_sgmv.add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
@ -107,7 +113,7 @@ class LoraLinear(nn.Module):
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
punica_sgmv.add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
@ -142,7 +148,7 @@ class LoraLinear(nn.Module):
|
||||
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if self.process_group.size() > 1:
|
||||
|
@ -1,252 +0,0 @@
|
||||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/utils/sgmv.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
import punica_kernels as _kernels
|
||||
|
||||
HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
|
||||
except ImportError:
|
||||
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
|
||||
_kernels = None
|
||||
HAS_SGMV = False
|
||||
|
||||
|
||||
MIN_SGMV_RANK = 8
|
||||
MIN_RANK_CUSTOM = 16
|
||||
MAX_RANK_CUSTOM = 128
|
||||
SGMV_BLOCK_SIZE = 16
|
||||
BGMV_MAX_RANK = 64
|
||||
|
||||
|
||||
def has_sgmv() -> bool:
|
||||
return HAS_SGMV
|
||||
|
||||
|
||||
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
|
||||
"""Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
|
||||
if not has_sgmv():
|
||||
return t
|
||||
|
||||
# tensor parallelism will result in effective rank being divided by world_size,
|
||||
# so we need to scale the min rank to offset that effect
|
||||
min_rank = MIN_SGMV_RANK * world_size
|
||||
|
||||
# if we're at or below the min rank, pad up to the min rank
|
||||
# otherwise, pad to the nearest multiple of the block size
|
||||
current_rank = t.size(dim)
|
||||
target_rank = (
|
||||
min_rank
|
||||
if current_rank <= min_rank
|
||||
else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
|
||||
)
|
||||
if current_rank == target_rank:
|
||||
return t
|
||||
|
||||
pad_size = target_rank - current_rank
|
||||
|
||||
# see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
|
||||
pad = [0, 0] * t.dim()
|
||||
pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
|
||||
pad = tuple(pad)
|
||||
|
||||
return F.pad(t, pad, mode="constant", value=0.0)
|
||||
|
||||
|
||||
def use_cutlass_shrink(lora_rank: int) -> bool:
|
||||
return lora_rank < MIN_RANK_CUSTOM
|
||||
|
||||
|
||||
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
|
||||
if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
|
||||
return t.transpose(0, 1)
|
||||
return t
|
||||
|
||||
|
||||
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
|
||||
def add_lora_sgmv_cutlass(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_ptr: torch.Tensor,
|
||||
wb_ptr: torch.Tensor,
|
||||
s_start: torch.Tensor,
|
||||
s_end: torch.Tensor,
|
||||
layer_idx: int,
|
||||
lora_rank: int,
|
||||
):
|
||||
"""
|
||||
Semantics:
|
||||
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
|
||||
|
||||
Args:
|
||||
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
|
||||
x: Shape: `[B, H1]`. Input vectors.
|
||||
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
|
||||
Weight matrix shape: `[num_layers, R, H1]`.
|
||||
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
|
||||
Weight matrix shape: `[num_layers, R, H2]`.
|
||||
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
|
||||
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
|
||||
layer_idx: Layer index of the weight matrices.
|
||||
"""
|
||||
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
|
||||
# Custom SGMV shrink only supports rank 16, 32, 64, 128
|
||||
_add_lora_sgmv_cutlass_legacy(
|
||||
y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
|
||||
)
|
||||
return
|
||||
|
||||
tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
|
||||
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
|
||||
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
|
||||
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
|
||||
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
|
||||
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
|
||||
|
||||
|
||||
def _add_lora_sgmv_cutlass_legacy(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_ptr: torch.Tensor,
|
||||
wb_ptr: torch.Tensor,
|
||||
s_start: torch.IntTensor,
|
||||
s_end: torch.IntTensor,
|
||||
layer_idx: int,
|
||||
lora_rank: int,
|
||||
):
|
||||
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
|
||||
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
|
||||
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
|
||||
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
|
||||
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
|
||||
return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
|
||||
tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
|
||||
return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
|
||||
def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
|
||||
return torch.empty((size,), dtype=torch.uint8, device=device)
|
||||
|
||||
|
||||
def get_tmp_expand_size(size: int) -> int:
|
||||
return _kernels.sgmv_cutlass_tmp_size(size)
|
||||
|
||||
|
||||
def get_tmp_tensors(
|
||||
nsegments: int, lora_rank: int, device: torch.device
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
|
||||
has_sgmv_available = has_sgmv()
|
||||
|
||||
if use_cutlass:
|
||||
tmp = get_tmp_tensor_for_size(nsegments, device)
|
||||
return tmp, tmp
|
||||
elif has_sgmv_available:
|
||||
return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
|
||||
else:
|
||||
tmp = get_tmp_tensor_for_size(nsegments, device)
|
||||
return tmp, tmp
|
||||
|
||||
|
||||
def lora_a_sgmv_cutlass(
|
||||
x: torch.Tensor,
|
||||
tmp: torch.Tensor,
|
||||
wa_ptr: torch.Tensor,
|
||||
s_start: torch.IntTensor,
|
||||
s_end: torch.IntTensor,
|
||||
layer_idx: int,
|
||||
lora_rank: int,
|
||||
) -> torch.Tensor:
|
||||
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
|
||||
if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
|
||||
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
|
||||
else:
|
||||
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
|
||||
return v
|
||||
|
||||
|
||||
def lora_b_sgmv_cutlass(
|
||||
y: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
tmp: torch.Tensor,
|
||||
wb_ptr: torch.Tensor,
|
||||
s_start: torch.IntTensor,
|
||||
s_end: torch.IntTensor,
|
||||
layer_idx: int,
|
||||
):
|
||||
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
|
||||
|
||||
|
||||
"""
|
||||
Semantics:
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
@ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
|
||||
v: Shape: `[B, R]`. Temporary vector.
|
||||
x: Shape: `[B, H1]`. Input vectors.
|
||||
wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
|
||||
wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
|
||||
indicies: Shape: `[B]`. Indices of the LoRA weights.
|
||||
layer_idx: Layer index of LoRA weights.
|
||||
scale: Scaling factor.
|
||||
"""
|
||||
|
||||
|
||||
def add_lora_a_bgmv(
|
||||
v: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_T_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
):
|
||||
_kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
|
||||
|
||||
|
||||
def add_lora_b_bgmv(
|
||||
y: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
wb_T_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
):
|
||||
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
|
||||
|
||||
|
||||
def segmented_matmul(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w: List[torch.Tensor],
|
||||
b: List[torch.Tensor],
|
||||
s_start: torch.IntTensor,
|
||||
s_end: torch.IntTensor,
|
||||
):
|
||||
for i in range(len(w)):
|
||||
if s_end[i] - s_start[i] <= 0:
|
||||
continue
|
||||
|
||||
xi = x[s_start[i] : s_end[i]]
|
||||
wi = w[i]
|
||||
bi = b[i]
|
||||
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
|