diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py
index 8ec9fb461..34c770402 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/common.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py
@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]

@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
             "block_list",
             "block_mapping",
             "block_usage",
-            "block_scales",
             "block_groups",
             "attn_bias",
         ],
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index f34e93abc..1d73dcb3b 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -74,7 +74,6 @@ def paged_attention(
         block_list=hpu_attention_meta.block_list,
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
-        block_scales=hpu_attention_meta.block_scales,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
         matmul_qk_op=Matmul(),
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index a2cbf30c0..d4ff3f707 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -70,7 +70,6 @@ from text_generation_server.utils.import_utils import (
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
-from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.bucketing import HPUBucketingContext

 tracer = trace.get_tracer(__name__)
@@ -149,11 +148,6 @@ def prepare_for_decode(
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
     mask = mask >= block_usage.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
-    ones = torch.ones(
-        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
-    )
-    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
-    block_scales = torch.reciprocal(torch.maximum(ones, sums))
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
             block_list=block_list,
@@ -161,7 +155,6 @@ def prepare_for_decode(
             block_usage=block_usage,
             block_mapping=block_mapping.to(dtype),
             attn_bias=attn_bias,
-            block_scales=block_scales,
         )
     )
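
For reference, here is a minimal pure-torch sketch of what the removed block_scales field used to encode. It is an illustration only: the real batch2block/block2batch helpers come from vllm_hpu_extension.ops and run inside the HPU flat_pa kernel path; the matmul-based stand-ins, the legacy_block_scales helper name, and the sample values below are assumptions made for this sketch, not part of the change.

import torch


def block2batch(values, block_mapping):
    # Collapse per-block values into per-sequence totals, assuming block_mapping
    # is a one-hot [num_blocks, batch_size] matrix derived from block_groups.
    return block_mapping.t() @ values


def batch2block(values, block_mapping):
    # Broadcast per-sequence values back onto the blocks that belong to them.
    return block_mapping @ values


def legacy_block_scales(block_groups, batch_size):
    # Hypothetical helper reproducing the logic removed from prepare_for_decode.
    block_mapping = torch.nn.functional.one_hot(
        block_groups, num_classes=batch_size
    ).to(torch.float32)
    ones = torch.ones(block_mapping.size(0), dtype=block_mapping.dtype)
    # For each block: how many blocks does its sequence occupy?
    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
    # 1 / max(1, count), i.e. a per-block normalization factor.
    return torch.reciprocal(torch.maximum(ones, sums))


# Two sequences holding 3 and 1 blocks -> tensor([0.3333, 0.3333, 0.3333, 1.0000])
print(legacy_block_scales(torch.tensor([0, 0, 0, 1]), batch_size=2))

After this change nothing in the backend passes block_scales to ops.flat_pa, so carrying the tensor through HPUPagedAttentionMetadata and computing it in prepare_for_decode would be dead weight; dropping the field end to end keeps the metadata in sync with the attention call.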