flashinfer: switch to plan API (#2904)

This change doesn't switch `forward` to `run` yet, since it requires
that we have access to the softmax scale and the logit softcap outside
the model.
This commit is contained in:
Daniël de Kok 2025-01-17 18:18:02 +01:00 committed by GitHub
parent 8f6146f11a
commit 630f198624
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 5 deletions

View File

@ -235,7 +235,6 @@ def attention(
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )

View File

@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(
token = prefill_with_paged_kv_state.set(state) token = prefill_with_paged_kv_state.set(state)
try: try:
state.begin_forward( state.plan(
qo_indptr=cu_seqlens, qo_indptr=cu_seqlens,
paged_kv_indptr=indptr, paged_kv_indptr=indptr,
paged_kv_indices=block_tables, paged_kv_indices=block_tables,
@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
prefill_with_paged_kv_state.reset(token) prefill_with_paged_kv_state.reset(token)
@ -200,7 +199,7 @@ def use_decode_state(
token = decode_state.set(state) token = decode_state.set(state)
try: try:
state.begin_forward( state.plan(
indptr=indptr, indptr=indptr,
indices=block_tables, indices=block_tables,
last_page_len=last_page_len, last_page_len=last_page_len,
@ -214,6 +213,5 @@ def use_decode_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
decode_state.reset(token) decode_state.reset(token)