remove block_scales which is not needed anymore

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-04-11 01:27:49 -07:00
parent a83e9fe003
commit 76cc129796
3 changed files with 0 additions and 10 deletions

View File

@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
"block_list",
"block_mapping",
"block_usage",
"block_scales",
"block_groups",
"attn_bias",
],

View File

@ -74,7 +74,6 @@ def paged_attention(
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),

View File

@ -70,7 +70,6 @@ from text_generation_server.utils.import_utils import (
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.bucketing import HPUBucketingContext
tracer = trace.get_tracer(__name__)
@ -149,11 +148,6 @@ def prepare_for_decode(
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
ones = torch.ones(
(block_mapping.size(0),), device=device, dtype=block_mapping.dtype
)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list,
@ -161,7 +155,6 @@ def prepare_for_decode(
block_usage=block_usage,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_scales=block_scales,
)
)