Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded mask of the shape `[sum(q_len[i] * k_len[i] for i in range(batch_size))]`. This change unpads the custom mask (currently only used by Gemma 3) to that shape, assuming `q_len == k_len`, since the custom mask is only used during prefill.
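The unpadding itself amounts to slicing out each sequence's valid `[seq_len, seq_len]` block and concatenating the flattened blocks. Below is a minimal sketch of the idea in PyTorch; `unpad_custom_mask` and its arguments are illustrative names, not the repository's actual API, and it assumes `q_len[i] == k_len[i]` per sequence, as stated above.

```python
import torch

def unpad_custom_mask(padded_mask: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
    """Flatten a padded [batch_size, max_len, max_len] mask into the 1-D
    layout flashinfer expects: sum(q_len[i] * k_len[i]) total entries.

    Hypothetical helper: assumes q_len == k_len for each sequence, which
    holds when the custom mask is only applied during prefill.
    """
    pieces = [
        # Keep only the valid [seq_len, seq_len] block for sequence i,
        # then flatten it row-major (q-major, k-minor).
        padded_mask[i, :length, :length].reshape(-1)
        for i, length in enumerate(seq_lens)
    ]
    return torch.cat(pieces)

# Example: sequences of lengths 2 and 3, padded to max_len = 3,
# yield a flat mask with 2*2 + 3*3 = 13 entries.
mask = torch.ones(2, 3, 3, dtype=torch.bool)
assert unpad_custom_mask(mask, [2, 3]).numel() == 13
```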
Files in the attention backend directory:

- __init__.py
- common.py
- cuda.py
- flash_attn_triton.py
- flashinfer.py
- ipex.py
- kv_cache.py
- rocm.py