flashinfer: remove contiguous calls

This commit is contained in:
Daniël de Kok 2025-01-06 16:07:58 +00:00
parent 02e3dc49be
commit 6defe57d7a

View File

@@ -60,8 +60,7 @@ def paged_attention(
         from text_generation_server.layers.attention.flashinfer import decode_state
         return decode_state.get().forward(
-            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
-            query.contiguous(),
+            query,
             paged_kv_cache=(kv_cache.key, kv_cache.value),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
@@ -231,8 +230,7 @@ def attention(
         softcap = 0.0
     return prefill_with_paged_kv_state.get().forward(
-        # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
-        query.contiguous(),
+        query,
         causal=causal,
         paged_kv_cache=(kv_cache.key, kv_cache.value),
         logits_soft_cap=softcap,