Fix mask passed to flashinfer (#3324)

Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects an unpadded, flattened mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.

This change unpads the custom mask (currently only used by Gemma 3)
to this shape (assuming q_len == k_len, since we only use the custom
mask during prefill).
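
As a standalone illustration (not part of the commit), a minimal sketch of
this conversion, assuming `q_len == k_len`:

```python
import torch

# Hypothetical toy batch: two prefill sequences of lengths 2 and 3,
# padded to max_len = 3.
seq_lengths = torch.tensor([2, 3])
padded = torch.ones(2, 3, 3, dtype=torch.bool)  # [batch_size, max_len, max_len]

# flashinfer wants a single flat mask with sum(q_len[i] * k_len[i])
# elements; with q_len == k_len that is 2*2 + 3*3 = 13.
packed = torch.cat(
    [padded[i, :l, :l].flatten() for i, l in enumerate(seq_lengths)]
)
assert packed.numel() == int((seq_lengths**2).sum())  # 13
```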
Daniël de Kok 2025-09-08 19:47:03 +02:00 committed by GitHub
parent 4f067c22c3
commit c6071749db

@@ -1,6 +1,7 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import math
import flashinfer
import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
workspace: Optional[torch.Tensor] = None

def unpad_2d_mask(
    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
) -> torch.Tensor:
    # Like torch unpad_sequence, but for 2D masks.
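    # Each sequence keeps only its top-left length x length block; the
    # padded rows and columns beyond `length` are discarded.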
    unpadded_tensors = []
    for i, length in enumerate(seq_lengths):
        unpadded_matrix = attention_mask[i, :length, :length]
        unpadded_tensors.append(unpadded_matrix.flatten())
    packed_tensor = torch.cat(unpadded_tensors)
    return packed_tensor

def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
        last_page_len += 1

    token = prefill_with_paged_kv_state.set(state)

    # Attention masks are padded, unpad.
    if custom_mask is not None:
        bs = input_lengths.shape[0]
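        # Each per-sequence mask is square (q_len == k_len during
        # prefill) and padded to max_len, so isqrt recovers max_len.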
        seq_len = math.isqrt(custom_mask.numel() // bs)
        custom_mask = unpad_2d_mask(
            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
        )

    try:
        state.plan(
            qo_indptr=cu_seqlens,