mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
Fix mask passed to flashinfer
Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded mask of the shape `[sum(q_len[i] * k_len[i] for i in range(batch_size)]`. This change unpads the custom mask (currently only used by Gemma 3) to this shape (assuming q_len == k_len, since we only use the custom mask during prefill).
This commit is contained in:
parent
9dedeb89ac
commit
70e5f6c5dc
@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
from contextvars import ContextVar
|
||||
from contextlib import contextmanager
|
||||
import math
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
|
||||
workspace: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def unpad_2d_mask(
|
||||
attention_mask: torch.Tensor, seq_lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# Like torch unpad_sequence, but for 2D masks.
|
||||
unpadded_tensors = []
|
||||
for i, length in enumerate(seq_lengths):
|
||||
unpadded_matrix = attention_mask[i, :length, :length]
|
||||
unpadded_tensors.append(unpadded_matrix.flatten())
|
||||
|
||||
packed_tensor = torch.cat(unpadded_tensors)
|
||||
|
||||
return packed_tensor
|
||||
|
||||
|
||||
def get_workspace(device):
|
||||
"""Get shared flashinfer workspace."""
|
||||
global workspace
|
||||
@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
|
||||
last_page_len += 1
|
||||
|
||||
token = prefill_with_paged_kv_state.set(state)
|
||||
|
||||
# Attention masks are padded, unpad.
|
||||
if custom_mask is not None:
|
||||
bs = input_lengths.shape[0]
|
||||
seq_len = math.isqrt(custom_mask.numel() // bs)
|
||||
custom_mask = unpad_2d_mask(
|
||||
custom_mask.reshape(bs, seq_len, seq_len), input_lengths
|
||||
)
|
||||
|
||||
try:
|
||||
state.plan(
|
||||
qo_indptr=cu_seqlens,
|
||||
|
Loading…
Reference in New Issue
Block a user