Upgrade to new vllm extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Authored by Wang, Yi on 2025-05-22 21:29:16 +08:00; committed by GitHub
parent 674c514d44
commit f08b44ade5
3 changed files with 10 additions and 13 deletions

File 1 of 3 — Dockerfile (Gaudi backend):

@@ -98,7 +98,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

File 2 of 3 — HPU paged attention layer:

@@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@@ -126,6 +127,7 @@ def paged_attention(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),

@@ -160,6 +162,7 @@ def paged_attention_mla(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
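
Note: the new extension ops take an explicit block_size because the KV cache (see the kv_cache changes below) is now stored flat as (num_blocks * BLOCK_SIZE, num_heads, head_size) instead of (num_blocks, BLOCK_SIZE, num_heads, head_size). A minimal sketch of the idea; the dimensions are illustrative and the extension's internals may differ:

    import torch

    BLOCK_SIZE = 128                              # illustrative; TGI reads this from models.globals
    num_blocks, num_heads, head_size = 4, 8, 64   # illustrative

    flat_cache = torch.zeros(num_blocks * BLOCK_SIZE, num_heads, head_size)

    # With block_size known, the flat cache can be regrouped into per-block views
    # before the block_mapping / block_groups gather is applied.
    blocked = flat_cache.view(-1, BLOCK_SIZE, num_heads, head_size)
    assert blocked.shape == (num_blocks, BLOCK_SIZE, num_heads, head_size)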

File 3 of 3 — KV cache module:

@@ -5,7 +5,6 @@ import torch
 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops

 @dataclass

@@ -55,12 +54,12 @@ class KVCache:
         self.kv_cache = (
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),

@@ -129,7 +128,7 @@ class KVCompressCache(KVCache):
             raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
         self.kv_cache = torch.zeros(
-            (num_blocks, BLOCK_SIZE, 1, head_size),
+            (num_blocks * BLOCK_SIZE, 1, head_size),
             dtype=dtype,
             device=device,
         )
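
The allocation change above flattens the first two cache dimensions. A slot id addresses the same storage either way: row slot in the flat layout corresponds to [slot // BLOCK_SIZE, slot % BLOCK_SIZE] in the old blocked layout. A small self-contained check, with illustrative sizes not taken from the repository:

    import torch

    BLOCK_SIZE, num_blocks, num_heads, head_size = 128, 4, 8, 64   # illustrative
    blocked = torch.randn(num_blocks, BLOCK_SIZE, num_heads, head_size)
    flat = blocked.reshape(num_blocks * BLOCK_SIZE, num_heads, head_size)

    slot = 300                                    # a global slot id
    assert torch.equal(flat[slot], blocked[slot // BLOCK_SIZE, slot % BLOCK_SIZE])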
@@ -161,14 +160,11 @@
     ):
         """Store the key and value at the given slots."""
         ## TODO FP8 kv cache support
-        block_idx = slots // BLOCK_SIZE
-        block_offset = slots % BLOCK_SIZE
         if self.kv_cache.dtype == torch.float8_e4m3fn:
             key = torch.ops.hpu.cast_to_fp8_v2(
                 key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
             )[0]
-        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+        self.kv_cache.index_copy_(0, slots, key)

 def paged_reshape_and_cache(
@@ -180,8 +176,6 @@ def paged_reshape_and_cache(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
 ):
-    block_idx = slots // BLOCK_SIZE
-    block_offset = slots % BLOCK_SIZE
     if key_cache.dtype == torch.float8_e4m3fn:
         key = torch.ops.hpu.cast_to_fp8_v2(
             key, k_scale, False, False, torch.float8_e4m3fn
@@ -189,8 +183,8 @@ def paged_reshape_and_cache(
         value = torch.ops.hpu.cast_to_fp8_v2(
             value, v_scale, False, False, torch.float8_e4m3fn
         )[0]
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    key_cache.index_copy_(0, slots, key)
+    value_cache.index_copy_(0, slots, value)

 def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
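
With the flat layout, writing new keys and values no longer needs the block_idx / block_offset bookkeeping or the extension's cache_ops.insert_or_update_cache: torch.Tensor.index_copy_ along dim 0 scatters each token's entry to the row given by its slot id. A minimal sketch of the store path; shapes and slot ids are illustrative, not from the repository:

    import torch

    BLOCK_SIZE, num_blocks, num_heads, head_size = 128, 4, 8, 64   # illustrative
    key_cache = torch.zeros(num_blocks * BLOCK_SIZE, num_heads, head_size)

    # Three incoming tokens, each already assigned a global slot by the scheduler.
    key = torch.randn(3, num_heads, head_size)
    slots = torch.tensor([0, 1, 130])             # illustrative slot ids

    # key[i] lands in row slots[i]; equivalent to the old per-block insert where
    # slot = block_idx * BLOCK_SIZE + block_offset.
    key_cache.index_copy_(0, slots, key)
    assert torch.equal(key_cache[130], key[2])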