diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index 526dbceca..f34e93abc 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -68,7 +68,7 @@ def paged_attention(
 ):
     batch_size, head_num, head_size = query.shape
     output = ops.flat_pa(
-        query=query,
+        query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
         value_cache=kv_cache.value,
         block_list=hpu_attention_meta.block_list,
diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py
index 6c8d637e5..0dc5cdafd 100644
--- a/backends/gaudi/server/text_generation_server/layers/fp8.py
+++ b/backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -11,7 +11,7 @@ from text_generation_server.utils.weights import (
 )
 
 from vllm_hpu_extension.ops import scaled_fp8_quant
-from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
 import habana_frameworks.torch.utils.experimental as htexp
 
 w8a8_block_fp8_matmul = None