Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-24 04:22:10 +00:00)
Upgrade to new vllm extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 674c514d44
commit f08b44ade5
@@ -98,7 +98,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
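Note on the Dockerfile change above: the extension install moves from a pinned HabanaAI commit (a060794) to the bmax_fix branch of the sywangyi fork, which provides the updated extension ops that the code changes below target (a flat, slot-indexed KV cache and an explicit block_size argument). A hypothetical post-build check, not part of this commit, assuming the package is distributed under the name vllm-hpu-extension:

    # Hypothetical sanity check inside the built image; the import fails if the
    # extension wheel is missing, and the version reflects whichever revision
    # was installed from the bmax_fix branch.
    import importlib.metadata

    import vllm_hpu_extension  # noqa: F401

    print(importlib.metadata.version("vllm-hpu-extension"))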
@@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@@ -126,6 +127,7 @@ def paged_attention(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
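The one functional change in this call (and in the paged_attention_mla call below) is the new block_size=BLOCK_SIZE argument: once the KV cache is stored as a single flat, slot-indexed tensor (see the KVCache hunks further down), the paged-attention op presumably can no longer read the block size off the cache shape, so it is passed explicitly. A minimal sketch of the block/slot arithmetic it is needed for (illustration only, not code from this repository):

    # Sketch: with a flat cache of shape (num_blocks * BLOCK_SIZE, heads, head_size),
    # cache row r belongs to logical block r // BLOCK_SIZE at offset r % BLOCK_SIZE.
    def rows_for_block(block_id: int, block_size: int) -> range:
        start = block_id * block_size
        return range(start, start + block_size)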
@@ -160,6 +162,7 @@ def paged_attention_mla(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
@@ -5,7 +5,6 @@ import torch

 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops


 @dataclass
@@ -55,12 +54,12 @@ class KVCache:

         self.kv_cache = (
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
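The cache layout changes from 4-D (num_blocks, BLOCK_SIZE, num_heads, head_size) to flat (num_blocks * BLOCK_SIZE, num_heads, head_size), so a global slot index addresses a cache row directly instead of going through a (block, offset) pair. The two layouts hold the same data; a small standalone illustration of the correspondence (made-up sizes, not code from this repository):

    import torch

    num_blocks, block_size, num_heads, head_size = 4, 8, 2, 16

    blocked = torch.randn(num_blocks, block_size, num_heads, head_size)
    flat = blocked.reshape(num_blocks * block_size, num_heads, head_size)

    slot = 19  # global slot index
    block_idx, block_offset = slot // block_size, slot % block_size
    assert torch.equal(flat[slot], blocked[block_idx, block_offset])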
@@ -129,7 +128,7 @@ class KVCompressCache(KVCache):
             raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

         self.kv_cache = torch.zeros(
-            (num_blocks, BLOCK_SIZE, 1, head_size),
+            (num_blocks * BLOCK_SIZE, 1, head_size),
             dtype=dtype,
             device=device,
         )
@@ -161,14 +160,11 @@ class KVCompressCache(KVCache):
     ):
         """Store the key and value at the given slots."""
         ## TODO FP8 kv cache support

-        block_idx = slots // BLOCK_SIZE
-        block_offset = slots % BLOCK_SIZE
         if self.kv_cache.dtype == torch.float8_e4m3fn:
             key = torch.ops.hpu.cast_to_fp8_v2(
                 key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
             )[0]
-        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+        self.kv_cache.index_copy_(0, slots, key)


 def paged_reshape_and_cache(
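With the flat layout, the store path drops both the block_idx/block_offset arithmetic and the extension's cache_ops.insert_or_update_cache helper: Tensor.index_copy_(0, slots, key) writes row i of key into cache row slots[i], which is the slot-to-row mapping the flat cache was built for. A minimal sketch of that behaviour (standalone, invented shapes):

    import torch

    BLOCK_SIZE, head_size = 4, 8
    kv_cache = torch.zeros(6 * BLOCK_SIZE, 1, head_size)  # flat compressed cache

    slots = torch.tensor([3, 10, 17])       # one destination slot per stored token
    key = torch.randn(slots.numel(), 1, head_size)

    kv_cache.index_copy_(0, slots, key)     # cache row slots[i] <- key[i]
    assert torch.equal(kv_cache[slots], key)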
@@ -180,8 +176,6 @@ def paged_reshape_and_cache(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
 ):
-    block_idx = slots // BLOCK_SIZE
-    block_offset = slots % BLOCK_SIZE
     if key_cache.dtype == torch.float8_e4m3fn:
         key = torch.ops.hpu.cast_to_fp8_v2(
             key, k_scale, False, False, torch.float8_e4m3fn
@@ -189,8 +183,8 @@ def paged_reshape_and_cache(
         value = torch.ops.hpu.cast_to_fp8_v2(
             value, v_scale, False, False, torch.float8_e4m3fn
         )[0]
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    key_cache.index_copy_(0, slots, key)
+    value_cache.index_copy_(0, slots, value)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales: