Fix mask passed to flashinfer (#3324)

Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects a packed, unpadded mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.

This change unpads the custom mask (currently only used by Gemma 3)
to this shape (assuming q_len == k_len, since we only use the custom
mask during prefill).
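
As a sketch of the shape mismatch (the tensors and lengths below are
made up for illustration, not taken from the diff):

    import torch

    # Two prefill sequences of lengths 2 and 3, padded to max_len = 3.
    seq_lengths = torch.tensor([2, 3])
    padded = torch.zeros(2, 3, 3, dtype=torch.bool)  # [batch_size, max_len, max_len]
    for i, length in enumerate(seq_lengths):
        # Causal mask over the real tokens of sequence i.
        padded[i, :length, :length] = torch.tril(
            torch.ones(length, length, dtype=torch.bool)
        )

    # flashinfer wants the per-sequence q_len x k_len blocks flattened and
    # concatenated: 2*2 + 3*3 = 13 elements here, not 2 * 3 * 3 = 18.
    packed = torch.cat(
        [padded[i, :l, :l].flatten() for i, l in enumerate(seq_lengths)]
    )
    assert packed.numel() == sum(int(l) ** 2 for l in seq_lengths)  # 13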
Daniël de Kok, 2025-09-08 19:47:03 +02:00 (committed by GitHub)
parent 4f067c22c3
commit c6071749db

@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math

 import flashinfer
 import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex

 workspace: Optional[torch.Tensor] = None

+
+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
     last_page_len += 1

     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,
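
The `seq_len` recovery in the hunk above works because the padded mask
holds exactly `bs * max_len * max_len` elements and is square per
sequence (q_len == k_len during prefill). A quick worked check, with
hypothetical numbers:

    import math

    import torch

    # Hypothetical prefill batch: bs = 2, masks padded to max_len = 5.
    bs = 2
    custom_mask = torch.zeros(bs * 5 * 5, dtype=torch.bool)  # flattened, padded

    # numel == bs * max_len**2, so the integer square root recovers max_len.
    seq_len = math.isqrt(custom_mask.numel() // bs)
    assert seq_len == 5
    assert custom_mask.reshape(bs, seq_len, seq_len).shape == (2, 5, 5)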