From c6071749db61208dc22f658689f37f4eb803bde6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 8 Sep 2025 19:47:03 +0200
Subject: [PATCH] Fix mask passed to flashinfer (#3324)

Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects an unpadded mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.

This change unpads the custom mask (currently only used by Gemma 3) to
this shape (assuming q_len == k_len, since we only use the custom mask
during prefill).
---
 .../layers/attention/flashinfer.py | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index f78475d5..daea3f9b 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math
 
 import flashinfer
 import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
 workspace: Optional[torch.Tensor] = None
 
 
+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
         last_page_len += 1
 
     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,
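
For reference, a minimal sketch of the shape contract described in the commit
message (illustration only, not part of the patch; the batch lengths and mask
values below are made up, and the helper is re-declared inline just to keep the
snippet self-contained): a padded `[batch_size, max_len, max_len]` mask is
packed into a 1-D tensor of length `sum(q_len[i] * k_len[i])`, assuming
`q_len == k_len` as during prefill.

import torch

def unpad_2d_mask(attention_mask, seq_lengths):
    # Same packing as the helper added in this patch.
    return torch.cat(
        [attention_mask[i, :l, :l].flatten() for i, l in enumerate(seq_lengths)]
    )

# Hypothetical batch of two prefill sequences with lengths 2 and 3.
seq_lengths = torch.tensor([2, 3])
max_len = int(seq_lengths.max())
padded = torch.zeros(len(seq_lengths), max_len, max_len, dtype=torch.bool)
for i, l in enumerate(seq_lengths.tolist()):
    # Causal (lower-triangular) mask per sequence, zero-padded to max_len.
    padded[i, :l, :l] = torch.ones(l, l).tril().bool()

unpadded = unpad_2d_mask(padded, seq_lengths)
# 2*2 + 3*3 = 13 elements: the flattened, unpadded layout the commit
# message says flashinfer expects.
assert unpadded.shape == (13,)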