Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 13:52:07 +00:00)
remove block_scales, which is not needed anymore
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent a83e9fe003, commit 76cc129796
@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]

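For context, this is how the metadata container reads once the hunk above is applied; the @dataclass decorator and the imports are assumptions reconstructed from the field syntax, since the diff itself shows only the fields.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class HPUPagedAttentionMetadata:
    # Remaining per-decode metadata once block_scales is dropped.
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]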
@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
             "block_list",
             "block_mapping",
             "block_usage",
-            "block_scales",
             "block_groups",
             "attn_bias",
         ],

@@ -74,7 +74,6 @@ def paged_attention(
         block_list=hpu_attention_meta.block_list,
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
-        block_scales=hpu_attention_meta.block_scales,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
         matmul_qk_op=Matmul(),

@@ -70,7 +70,6 @@ from text_generation_server.utils.import_utils import (
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
-from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.bucketing import HPUBucketingContext

 tracer = trace.get_tracer(__name__)
@@ -149,11 +148,6 @@ def prepare_for_decode(
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
     mask = mask >= block_usage.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
-    ones = torch.ones(
-        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
-    )
-    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
-    block_scales = torch.reciprocal(torch.maximum(ones, sums))
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
             block_list=block_list,
@@ -161,7 +155,6 @@ def prepare_for_decode(
             block_usage=block_usage,
             block_mapping=block_mapping.to(dtype),
             attn_bias=attn_bias,
-            block_scales=block_scales,
         )
     )

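For reference, a minimal standalone sketch of what the removed block_scales computation in prepare_for_decode was doing, assuming block_mapping is a one-hot (num_blocks, batch_size) matrix as consumed by batch2block/block2batch in vllm_hpu_extension; that layout and the expanded helpers below are assumptions for illustration, not part of this diff.

import torch

def removed_block_scales_sketch(block_mapping: torch.Tensor) -> torch.Tensor:
    # One entry per KV-cache block, as in the removed code.
    ones = torch.ones((block_mapping.size(0),), dtype=block_mapping.dtype)
    # block2batch(ones, block_mapping): number of blocks owned by each sequence.
    per_batch = block_mapping.t() @ ones
    # batch2block(..., block_mapping): broadcast that count back to each block.
    sums = block_mapping @ per_batch
    # One scale per block: 1 / max(1, blocks in the same sequence).
    return torch.reciprocal(torch.maximum(ones, sums))

# Example: sequence 0 owns three blocks, sequence 1 owns one.
block_mapping = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(removed_block_scales_sketch(block_mapping))
# tensor([0.3333, 0.3333, 0.3333, 1.0000])

Since the attention kernel no longer consumes this per-block factor (per the commit message), the dataclass field, the trimmed-metadata key, the call-site argument, and the batch2block/block2batch import are all dropped together.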