From 630f198624b6c405e5fcfb7f08f7f308026f68cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 17 Jan 2025 18:18:02 +0100 Subject: [PATCH] flashinfer: switch to plan API (#2904) This change doesn't switch `forward` to `run` yet, since it requires that we have access to the softmax scale and the logit softcap outside the model. --- server/text_generation_server/layers/attention/cuda.py | 1 - .../text_generation_server/layers/attention/flashinfer.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 7b5af3c4..f1469b3f 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -235,7 +235,6 @@ def attention( paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, - window_left=window_size_left, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, ) diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 909eea27..d2345184 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state( token = prefill_with_paged_kv_state.set(state) try: - state.begin_forward( + state.plan( qo_indptr=cu_seqlens, paged_kv_indptr=indptr, paged_kv_indices=block_tables, @@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state( ) yield finally: - state.end_forward() if token is not None: prefill_with_paged_kv_state.reset(token) @@ -200,7 +199,7 @@ def use_decode_state( token = decode_state.set(state) try: - state.begin_forward( + state.plan( indptr=indptr, indices=block_tables, last_page_len=last_page_len, @@ -214,6 +213,5 @@ def use_decode_state( ) yield finally: - state.end_forward() if token is not None: decode_state.reset(token)