From 4e6990267e3000e719760dd893105fdae4195a9c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 8 Jan 2025 12:21:50 +0000
Subject: [PATCH] flashinfer: fixup kv cache dtype

---
 .../layers/attention/flashinfer.py | 45 +++-----------------
 .../models/flash_causal_lm.py      |  5 ++-
 2 files changed, 9 insertions(+), 41 deletions(-)

diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index ea1bc1d7..909eea27 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    dtype: torch.dtype,
+    kv_dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -91,7 +92,8 @@ def use_prefill_with_paged_kv_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=dtype,
+            kv_data_type=kv_dtype,
+            q_data_type=q_dtype,
             page_size=page_size,
             window_left=-1 if window_left is None else window_left,
         )
@@ -113,41 +115,6 @@ def create_prefill_state(
     )
 
 
-@contextmanager
-def use_prefill_state(
-    *,
-    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
-    cu_seqlens: torch.Tensor,
-    num_heads: int,
-    num_kv_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    window_left: int,
-):
-    """
-    Context manager to set the active flashinfer prefill state to the given
-    `state` and parameters. This state will be used by all calls to the
-    `attention` function while the context manager is active.
-    """
-
-    token = prefill_state.set(state)
-    try:
-        state.begin_forward(
-            qo_indptr=cu_seqlens,
-            kv_indptr=cu_seqlens,
-            num_qo_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_size,
-            q_data_type=dtype,
-            window_left=-1 if window_left is None else window_left,
-        )
-        yield
-    finally:
-        state.end_forward()
-        if token is not None:
-            prefill_state.reset(token)
-
-
 def create_decode_state(
     *,
     device: torch.device,
@@ -205,7 +172,7 @@ def use_decode_state(
     head_size: int,
     page_size: int,
     kv_cache_dtype: torch.dtype,
-    dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -242,7 +209,7 @@ def use_decode_state(
             head_dim=head_size,
             page_size=page_size,
             data_type=kv_cache_dtype,
-            q_data_type=dtype,
+            q_data_type=q_dtype,
             window_left=-1 if window_left is None else window_left,
         )
         yield
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5d376990..c63ca1db 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
-                dtype=self.dtype,
+                kv_dtype=self.kv_cache_dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
         else:
@@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
                 kv_cache_dtype=self.kv_cache_dtype,
-                dtype=self.dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
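
Note (not part of the patch to apply): a minimal sketch of the situation the change handles, assuming a quantized KV cache. The paged KV cache can be stored in a narrower dtype than the queries, so the flashinfer planner must be told both dtypes instead of a single `dtype`. The shapes and the `torch.float8_e4m3fn` choice below are illustrative assumptions, not taken from the patch.

import torch

q_dtype = torch.bfloat16        # query/model dtype (self.dtype in the patch)
kv_dtype = torch.float8_e4m3fn  # paged KV cache dtype (self.kv_cache_dtype)

# Assumed paged layout: [num_blocks, 2 (K/V), page_size, num_kv_heads, head_size]
kv_cache = torch.zeros(4, 2, 16, 8, 128, dtype=kv_dtype)
query = torch.randn(3, 32, 128, dtype=q_dtype)

# With a single `dtype` argument one of the two would be mis-declared to the
# flashinfer plan; after the patch both are passed explicitly.
print(f"q_data_type={query.dtype}, kv_data_type={kv_cache.dtype}")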