diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index bd6c58b4..c4164556 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -98,7 +98,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index 1c2e37c7..8cca7a29 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@@ -126,6 +127,7 @@ def paged_attention(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
@@ -160,6 +162,7 @@ def paged_attention_mla(
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
+        block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
         matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
index cdd1e1d7..723c1ec0 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -5,7 +5,6 @@ import torch

 from text_generation_server.models.globals import BLOCK_SIZE
 from text_generation_server.utils.weights import Weights
-from vllm_hpu_extension import cache_ops


 @dataclass
@@ -55,12 +54,12 @@ class KVCache:

         self.kv_cache = (
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
             torch.zeros(
-                (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                (num_blocks * BLOCK_SIZE, num_heads, head_size),
                 dtype=dtype,
                 device=device,
             ),
@@ -129,7 +128,7 @@ class KVCompressCache(KVCache):
             raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

         self.kv_cache = torch.zeros(
-            (num_blocks, BLOCK_SIZE, 1, head_size),
+            (num_blocks * BLOCK_SIZE, 1, head_size),
             dtype=dtype,
             device=device,
         )
@@ -161,14 +160,11 @@ class KVCompressCache(KVCache):
     ):
         """Store the key and value at the given slots."""
         ## TODO FP8 kv cache support
-
-        block_idx = slots // BLOCK_SIZE
-        block_offset = slots % BLOCK_SIZE
         if self.kv_cache.dtype == torch.float8_e4m3fn:
             key = torch.ops.hpu.cast_to_fp8_v2(
                 key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
             )[0]
-        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+        self.kv_cache.index_copy_(0, slots, key)


def paged_reshape_and_cache(
@@ -180,8 +176,6 @@ def paged_reshape_and_cache(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
 ):
-    block_idx = slots // BLOCK_SIZE
-    block_offset = slots % BLOCK_SIZE
     if key_cache.dtype == torch.float8_e4m3fn:
         key = torch.ops.hpu.cast_to_fp8_v2(
             key, k_scale, False, False, torch.float8_e4m3fn
         )[0]
         value = torch.ops.hpu.cast_to_fp8_v2(
             value, v_scale, False, False, torch.float8_e4m3fn
         )[0]
-    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
-    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+    key_cache.index_copy_(0, slots, key)
+    value_cache.index_copy_(0, slots, value)


 def get_kv_scales(weights: Weights, prefix: str) -> KVScales: