Fix mask passed to flashinfer

Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects an unpadded mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.

This change unpads the custom mask (currently only used by Gemma 3)
to this shape (assuming q_len == k_len, since we only use the custom
mask during prefill).
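
For illustration only (not part of this commit), here is a minimal sketch of the shape conversion, assuming q_len == k_len as during prefill; the tensor names are made up:

```python
import torch

# Toy example: two sequences of lengths 2 and 3, padded to max_len = 3.
input_lengths = [2, 3]
bs, max_len = len(input_lengths), max(input_lengths)

# Padded custom mask: [batch_size, max_len, max_len].
padded_mask = torch.zeros(bs, max_len, max_len, dtype=torch.bool)
for i, length in enumerate(input_lengths):
    padded_mask[i, :length, :length] = True

# Unpadded layout: the per-sequence q_len x k_len blocks flattened and
# concatenated into a single 1D tensor.
unpadded_mask = torch.cat(
    [padded_mask[i, :length, :length].flatten() for i, length in enumerate(input_lengths)]
)

# 2 * 2 + 3 * 3 = 13 entries instead of the 2 * 3 * 3 = 18 padded entries.
assert unpadded_mask.numel() == sum(length * length for length in input_lengths)
```

In the change below, the padded sequence length is recovered from the padded mask as the integer square root of `custom_mask.numel() // batch_size` before unpadding.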
Daniël de Kok 2025-09-08 11:35:58 +00:00
parent 9dedeb89ac
commit 70e5f6c5dc

@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math
 
 import flashinfer
 import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
 
 workspace: Optional[torch.Tensor] = None
 
+
+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
         last_page_len += 1
 
     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,