mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
flashinfer: switch to plan API (#2904)
This change doesn't switch `forward` to `run` yet, since it requires that we have access to the softmax scale and the logit softcap outside the model.
This commit is contained in:
parent
8f6146f11a
commit
630f198624
@ -235,7 +235,6 @@ def attention(
|
|||||||
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
window_left=window_size_left,
|
|
||||||
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||||
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||||
)
|
)
|
||||||
|
@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(
|
|||||||
|
|
||||||
token = prefill_with_paged_kv_state.set(state)
|
token = prefill_with_paged_kv_state.set(state)
|
||||||
try:
|
try:
|
||||||
state.begin_forward(
|
state.plan(
|
||||||
qo_indptr=cu_seqlens,
|
qo_indptr=cu_seqlens,
|
||||||
paged_kv_indptr=indptr,
|
paged_kv_indptr=indptr,
|
||||||
paged_kv_indices=block_tables,
|
paged_kv_indices=block_tables,
|
||||||
@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
|
|||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
state.end_forward()
|
|
||||||
if token is not None:
|
if token is not None:
|
||||||
prefill_with_paged_kv_state.reset(token)
|
prefill_with_paged_kv_state.reset(token)
|
||||||
|
|
||||||
@ -200,7 +199,7 @@ def use_decode_state(
|
|||||||
token = decode_state.set(state)
|
token = decode_state.set(state)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state.begin_forward(
|
state.plan(
|
||||||
indptr=indptr,
|
indptr=indptr,
|
||||||
indices=block_tables,
|
indices=block_tables,
|
||||||
last_page_len=last_page_len,
|
last_page_len=last_page_len,
|
||||||
@ -214,6 +213,5 @@ def use_decode_state(
|
|||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
state.end_forward()
|
|
||||||
if token is not None:
|
if token is not None:
|
||||||
decode_state.reset(token)
|
decode_state.reset(token)
|
||||||
|
Loading…
Reference in New Issue
Block a user