remove block_scales, which is no longer needed

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-11 01:27:49 -07:00
parent a83e9fe003
commit 76cc129796
3 changed files with 0 additions and 10 deletions


@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]
@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
         "block_list",
         "block_mapping",
         "block_usage",
-        "block_scales",
         "block_groups",
         "attn_bias",
     ],


@@ -74,7 +74,6 @@ def paged_attention(
         block_list=hpu_attention_meta.block_list,
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
-        block_scales=hpu_attention_meta.block_scales,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
         matmul_qk_op=Matmul(),


@@ -70,7 +70,6 @@ from text_generation_server.utils.import_utils import (
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
-from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.bucketing import HPUBucketingContext

 tracer = trace.get_tracer(__name__)
@@ -149,11 +148,6 @@ def prepare_for_decode(
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
     mask = mask >= block_usage.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
-    ones = torch.ones(
-        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
-    )
-    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
-    block_scales = torch.reciprocal(torch.maximum(ones, sums))
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
             block_list=block_list,
@@ -161,7 +155,6 @@ def prepare_for_decode(
             block_usage=block_usage,
             block_mapping=block_mapping.to(dtype),
             attn_bias=attn_bias,
-            block_scales=block_scales,
         )
     )
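
For reference, a minimal sketch of what the deleted block_scales computation produced, using plain torch in place of vllm_hpu_extension's batch2block/block2batch helpers. It assumes block_mapping is the one-hot (num_blocks, batch_size) matrix those helpers operate on; the function name block_scales_reference is illustrative, not part of the codebase.

import torch

def block_scales_reference(block_mapping: torch.Tensor) -> torch.Tensor:
    ones = torch.ones(
        (block_mapping.size(0),), device=block_mapping.device, dtype=block_mapping.dtype
    )
    # block2batch equivalent: count how many blocks belong to each sequence.
    blocks_per_seq = block_mapping.t() @ ones
    # batch2block equivalent: broadcast each sequence's count back to its blocks.
    count_per_block = block_mapping @ blocks_per_seq
    # Each block was weighted by 1 / (blocks in its sequence), floored at 1.
    return torch.reciprocal(torch.maximum(ones, count_per_block))

Per the commit message, the attention path no longer consumes this per-block weight, so both the metadata field and its computation can be dropped.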