minor fix

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-05-15 23:34:48 -07:00
parent a184ce3876
commit b5e1ae9209

View File

@@ -156,7 +156,7 @@ def prepare_for_decode(
block_groups_device, num_classes=batch_size
)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-mask = mask >= block_usage.unsqueeze(-1)
+mask = mask >= block_usage_device.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata(
HPUPagedAttentionMetadata(