mirror of https://github.com/huggingface/text-generation-inference.git
Fix mask passed to flashinfer
Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded (packed) mask of shape `[sum(q_len[i] * k_len[i] for i in range(batch_size))]`. This change unpads the custom mask (currently only used by Gemma 3) to that shape, assuming `q_len == k_len`, which holds because the custom mask is only used during prefill.
This commit is contained in:
parent 9dedeb89ac
commit 70e5f6c5dc
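For context, here is a minimal sketch of the shape mismatch the message describes. It is not part of the commit, and the `input_lengths` values are made up:

```python
import torch

# Padded layout: one [max_len, max_len] mask per sequence, zero-filled.
input_lengths = torch.tensor([2, 3])
bs, max_len = input_lengths.shape[0], int(input_lengths.max())
padded = torch.zeros(bs, max_len, max_len, dtype=torch.bool)

# Packed layout flashinfer expects: sum(q_len[i] * k_len[i]) elements,
# with q_len == k_len during prefill.
packed_len = int((input_lengths * input_lengths).sum())

print(padded.numel(), packed_len)  # 18 vs. 13
```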
@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math

 import flashinfer
 import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
 workspace: Optional[torch.Tensor] = None


+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
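A quick sanity check of the helper's contract, assuming `unpad_2d_mask` as defined above is in scope (toy values, not from the repository):

```python
import torch

lengths = torch.tensor([2, 3])
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[0, :2, :2] = True  # valid region for sequence 0
mask[1, :3, :3] = True  # valid region for sequence 1

packed = unpad_2d_mask(mask, lengths)
# Each sequence contributes length * length elements: 2*2 + 3*3 == 13.
assert packed.shape == (int((lengths * lengths).sum()),)
```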
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
     last_page_len += 1

     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,
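One note on the call site above: the reshape suggests `custom_mask` arrives flattened with `bs * max_len * max_len` elements, in which case `math.isqrt(custom_mask.numel() // bs)` recovers the padded length exactly. A tiny illustration with made-up sizes:

```python
import math

bs, max_len = 2, 3
numel = bs * max_len * max_len             # 18 elements in the flattened padded mask
assert math.isqrt(numel // bs) == max_len  # isqrt(9) == 3
```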