diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index f78475d5..daea3f9b 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math
 
 import flashinfer
 import torch
@@ -20,6 +21,20 @@ decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = Contex
 workspace: Optional[torch.Tensor] = None
 
 
+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
         last_page_len += 1
 
     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,
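
For reference, a minimal standalone sketch of what the new helper does on the prefill path: each request's padded seq_len x seq_len mask is cropped to its true length and the per-request blocks are concatenated into one flat tensor. The batch size, lengths, and mask values below are made up for illustration, and unpad_2d_mask is re-declared locally so the snippet runs on its own.

    # Illustrative sketch, not part of the patch.
    import math
    import torch

    def unpad_2d_mask(attention_mask, seq_lengths):
        # Same logic as the helper added above: keep the top-left
        # length x length block of each padded mask and concatenate
        # the flattened blocks into a single 1D tensor.
        return torch.cat(
            [attention_mask[i, :l, :l].flatten() for i, l in enumerate(seq_lengths)]
        )

    # Batch of two requests padded to a common seq_len of 4; real lengths 2 and 4.
    input_lengths = torch.tensor([2, 4])
    bs = input_lengths.shape[0]
    custom_mask = torch.ones(bs * 4 * 4, dtype=torch.bool)  # padded, flattened

    # Mirrors the new code in use_prefill_with_paged_kv_state.
    seq_len = math.isqrt(custom_mask.numel() // bs)  # 4
    packed = unpad_2d_mask(custom_mask.reshape(bs, seq_len, seq_len), input_lengths)

    # Packed size is the sum of per-request squares: 2*2 + 4*4 = 20 elements.
    assert packed.numel() == int((input_lengths**2).sum())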