Upgrade to new vllm extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-05-22 21:29:16 +08:00 committed by GitHub
parent 674c514d44
commit f08b44ade5
3 changed files with 10 additions and 13 deletions

View File

@@ -98,7 +98,7 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
-RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
+from text_generation_server.models.globals import BLOCK_SIZE
SUPPORTS_WINDOWING = False
@@ -126,6 +127,7 @@ def paged_attention(
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
+block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
@@ -160,6 +162,7 @@ def paged_attention_mla(
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
+block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
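
Both paged attention paths now pass block_size=BLOCK_SIZE explicitly. Because the KV cache is flattened to a single slot dimension in the next file (num_blocks * BLOCK_SIZE rows instead of a (num_blocks, BLOCK_SIZE, ...) layout), the attention op can no longer read the block size off the cache shape, so the caller has to supply it to translate block_groups/block_mapping into row ranges. A minimal sketch of that translation, not the extension kernel itself; BLOCK_SIZE here is a stand-in value and block_view/block_ids are hypothetical names:

import torch

BLOCK_SIZE = 128  # assumption: stand-in for text_generation_server.models.globals.BLOCK_SIZE

def block_view(flat_cache: torch.Tensor, block_ids: torch.Tensor) -> torch.Tensor:
    # flat_cache: (num_blocks * BLOCK_SIZE, num_heads, head_size), one row per slot.
    num_heads, head_size = flat_cache.shape[1:]
    # Rows of block b occupy [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE); reshaping exposes them.
    blocked = flat_cache.view(-1, BLOCK_SIZE, num_heads, head_size)
    return blocked[block_ids]  # (len(block_ids), BLOCK_SIZE, num_heads, head_size)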

View File

@@ -5,7 +5,6 @@ import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops
@dataclass
@@ -55,12 +54,12 @@ class KVCache:
self.kv_cache = (
torch.zeros(
-(num_blocks, BLOCK_SIZE, num_heads, head_size),
+(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
-(num_blocks, BLOCK_SIZE, num_heads, head_size),
+(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
@@ -129,7 +128,7 @@ class KVCompressCache(KVCache):
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
-(num_blocks, BLOCK_SIZE, 1, head_size),
+(num_blocks * BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@@ -161,14 +160,11 @@
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
-block_idx = slots // BLOCK_SIZE
-block_offset = slots % BLOCK_SIZE
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
-cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+self.kv_cache.index_copy_(0, slots, key)
def paged_reshape_and_cache(
@@ -180,8 +176,6 @@ def paged_reshape_and_cache(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
):
-block_idx = slots // BLOCK_SIZE
-block_offset = slots % BLOCK_SIZE
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
@@ -189,8 +183,8 @@ def paged_reshape_and_cache(
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
-cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+key_cache.index_copy_(0, slots, key)
+value_cache.index_copy_(0, slots, value)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
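
With the cache allocated as one flat slot dimension, the block_idx/block_offset arithmetic and cache_ops.insert_or_update_cache are no longer needed: a slot id is directly a row index, and index_copy_ scatters the new keys and values into place. A standalone sketch of that equivalence with toy sizes (the names and numbers below are illustrative, not taken from the code above, and the direct assignment only stands in for what insert_or_update_cache presumably did):

import torch

BLOCK_SIZE, num_blocks, num_heads, head_size = 4, 8, 2, 16  # toy sizes

# Old layout, addressed by (block_idx, block_offset).
old_cache = torch.zeros(num_blocks, BLOCK_SIZE, num_heads, head_size)
# New layout, one row per slot.
new_cache = torch.zeros(num_blocks * BLOCK_SIZE, num_heads, head_size)

slots = torch.tensor([5, 9, 10])  # example slot ids
key = torch.randn(len(slots), num_heads, head_size)

# The removed lines computed these indices and passed them to the extension op;
# direct assignment stands in for that op here.
block_idx, block_offset = slots // BLOCK_SIZE, slots % BLOCK_SIZE
old_cache[block_idx, block_offset] = key

# The new code: the slot id indexes the first dimension directly.
new_cache.index_copy_(0, slots, key)

assert torch.equal(old_cache.view(-1, num_heads, head_size), new_cache)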