diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index f5031d6f..09a05585 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -156,7 +156,7 @@ def prepare_for_decode(
         block_groups_device, num_classes=batch_size
     )
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return trim_attn_metadata(
        HPUPagedAttentionMetadata(
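
The hunk swaps the host-side block_usage tensor for its on-device copy, block_usage_device, so the comparison against the device-resident arange tensor no longer mixes devices. A minimal sketch of the mask construction follows; it assumes an illustrative BLOCK_SIZE of 8 and made-up block_usage values, and uses CPU in place of the HPU device so it runs anywhere. The real tensors come from prepare_for_decode.

import math
import torch

BLOCK_SIZE = 8                      # assumed paged-attention block size for illustration
device = torch.device("cpu")        # stands in for the HPU device
dtype = torch.bfloat16

# Number of valid slots in each KV-cache block (one entry per block); made-up values.
block_usage_device = torch.tensor([8, 3, 5], device=device, dtype=torch.int32)

# Slot positions within a block: shape (1, BLOCK_SIZE).
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)

# True where the slot index is at or past the block's usage, i.e. padding.
# Comparing against a host tensor here (the old block_usage) would mix devices.
mask = mask >= block_usage_device.unsqueeze(-1)

# Padding slots get -inf bias so softmax assigns them zero attention weight.
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
print(attn_bias)

Keeping every operand of the comparison on the same device avoids either a runtime device-mismatch error or an implicit host-to-device transfer inside the decode hot path.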