mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-23 20:12:06 +00:00
Upgrade to new vllm extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
674c514d44
commit
f08b44ade5
@ -98,7 +98,7 @@ RUN cd server && \
|
||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||
pip install . --no-cache-dir
|
||||
RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
|
||||
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
|
@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
import os
|
||||
from text_generation_server.models.globals import BLOCK_SIZE
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
@ -126,6 +127,7 @@ def paged_attention(
|
||||
block_mapping=hpu_attention_meta.block_mapping,
|
||||
block_bias=hpu_attention_meta.attn_bias,
|
||||
block_groups=hpu_attention_meta.block_groups,
|
||||
block_size=BLOCK_SIZE,
|
||||
scale=softmax_scale,
|
||||
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
|
||||
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
|
||||
@ -160,6 +162,7 @@ def paged_attention_mla(
|
||||
block_mapping=hpu_attention_meta.block_mapping,
|
||||
block_bias=hpu_attention_meta.attn_bias,
|
||||
block_groups=hpu_attention_meta.block_groups,
|
||||
block_size=BLOCK_SIZE,
|
||||
scale=softmax_scale,
|
||||
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
|
||||
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
|
||||
|
@ -5,7 +5,6 @@ import torch
|
||||
|
||||
from text_generation_server.models.globals import BLOCK_SIZE
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from vllm_hpu_extension import cache_ops
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -55,12 +54,12 @@ class KVCache:
|
||||
|
||||
self.kv_cache = (
|
||||
torch.zeros(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
(num_blocks * BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
(num_blocks * BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
@ -129,7 +128,7 @@ class KVCompressCache(KVCache):
|
||||
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
|
||||
|
||||
self.kv_cache = torch.zeros(
|
||||
(num_blocks, BLOCK_SIZE, 1, head_size),
|
||||
(num_blocks * BLOCK_SIZE, 1, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
@ -161,14 +160,11 @@ class KVCompressCache(KVCache):
|
||||
):
|
||||
"""Store the key and value at the given slots."""
|
||||
## TODO FP8 kv cache support
|
||||
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
if self.kv_cache.dtype == torch.float8_e4m3fn:
|
||||
key = torch.ops.hpu.cast_to_fp8_v2(
|
||||
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
|
||||
self.kv_cache.index_copy_(0, slots, key)
|
||||
|
||||
|
||||
def paged_reshape_and_cache(
|
||||
@ -180,8 +176,6 @@ def paged_reshape_and_cache(
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
):
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
if key_cache.dtype == torch.float8_e4m3fn:
|
||||
key = torch.ops.hpu.cast_to_fp8_v2(
|
||||
key, k_scale, False, False, torch.float8_e4m3fn
|
||||
@ -189,8 +183,8 @@ def paged_reshape_and_cache(
|
||||
value = torch.ops.hpu.cast_to_fp8_v2(
|
||||
value, v_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||
key_cache.index_copy_(0, slots, key)
|
||||
value_cache.index_copy_(0, slots, value)
|
||||
|
||||
|
||||
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
|
||||
|
Loading…
Reference in New Issue
Block a user