flashinfer: fixup kv cache dtype

Author: Daniël de Kok
Date: 2025-01-08 12:21:50 +00:00
Commit: 4e6990267e (parent: 41948e240f)

2 changed files with 9 additions and 41 deletions

@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    dtype: torch.dtype,
+    kv_dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -91,7 +92,8 @@ def use_prefill_with_paged_kv_state(
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
-            q_data_type=dtype,
+            kv_data_type=kv_dtype,
+            q_data_type=q_dtype,
            page_size=page_size,
            window_left=-1 if window_left is None else window_left,
        )
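Before this change the context manager took a single `dtype` and forwarded it only as `q_data_type`, so flashinfer was never told when the paged KV cache is stored more narrowly than the queries. A minimal self-contained sketch of the situation the split addresses, in plain PyTorch rather than flashinfer; shapes and dtypes are illustrative, and a PyTorch build with float8 support (>= 2.1) is assumed:

import torch

kv_dtype = torch.float8_e4m3fn   # storage dtype of the cached keys/values
q_dtype = torch.float16          # compute dtype of the queries

k = torch.randn(16, 128, dtype=q_dtype)   # one page of keys for one head
k_cached = k.to(kv_dtype)                 # what the cache holds: 1 byte/element
q = torch.randn(1, 128, dtype=q_dtype)

# A kernel that is told kv_data_type upcasts cache entries before the dot
# product; emulated here with explicit casts.
scores = q.float() @ k_cached.float().t()
print(scores.shape)   # torch.Size([1, 16])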
@@ -113,41 +115,6 @@ def create_prefill_state(
     )


-@contextmanager
-def use_prefill_state(
-    *,
-    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
-    cu_seqlens: torch.Tensor,
-    num_heads: int,
-    num_kv_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    window_left: int,
-):
-    """
-    Context manager to set the active flashinfer prefill state to the given
-    `state` and parameters. This state will be used by all calls to the
-    `attention` function while the context manager is active.
-    """
-
-    token = prefill_state.set(state)
-    try:
-        state.begin_forward(
-            qo_indptr=cu_seqlens,
-            kv_indptr=cu_seqlens,
-            num_qo_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_size,
-            q_data_type=dtype,
-            window_left=-1 if window_left is None else window_left,
-        )
-        yield
-    finally:
-        state.end_forward()
-        if token is not None:
-            prefill_state.reset(token)
-
-
 def create_decode_state(
     *,
     device: torch.device,
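The removed `use_prefill_state` handled the ragged (non-paged) KV case; the surviving helpers follow the same ContextVar protocol it used: stash the active wrapper so that `attention` can find it without threading it through every call, and restore the previous value on exit. A stand-alone sketch of that pattern, with `begin_forward`/`end_forward` reduced to comments and a stand-in state object:

from contextlib import contextmanager
from contextvars import ContextVar

# `active_state` is a stand-in for the module's prefill_state ContextVar.
active_state: ContextVar = ContextVar("active_state", default=None)

@contextmanager
def use_state(state):
    token = active_state.set(state)
    try:
        # the real helper calls state.begin_forward(...) here
        yield
    finally:
        # ... and state.end_forward() here
        active_state.reset(token)

with use_state("wrapper"):
    assert active_state.get() == "wrapper"
assert active_state.get() is None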
@@ -205,7 +172,7 @@ def use_decode_state(
     head_size: int,
     page_size: int,
     kv_cache_dtype: torch.dtype,
-    dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -242,7 +209,7 @@ def use_decode_state(
            head_dim=head_size,
            page_size=page_size,
            data_type=kv_cache_dtype,
-            q_data_type=dtype,
+            q_data_type=q_dtype,
            window_left=-1 if window_left is None else window_left,
        )
        yield
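On the decode side this is a pure rename: the wrapper already received the cache dtype as `data_type=kv_cache_dtype`, so only `dtype` becomes `q_dtype` for symmetry with the prefill path. Because the signatures use a bare `*`, every argument is keyword-only and the rename fails loudly for stale callers. A hypothetical stand-in, not TGI's actual function, demonstrating that behavior:

# Renaming `dtype` to `q_dtype` raises a TypeError for old call sites
# instead of silently binding the wrong positional value.
def use_decode_state_sketch(*, kv_cache_dtype, q_dtype, window_left=None):
    return kv_cache_dtype, q_dtype

try:
    use_decode_state_sketch(kv_cache_dtype="fp8", dtype="fp16")  # old kwarg
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'dtype'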

@@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
-                dtype=self.dtype,
+                kv_dtype=self.kv_cache_dtype,
+                q_dtype=self.dtype,
                window_left=self.sliding_window,
            )
        else:
@@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                kv_cache_dtype=self.kv_cache_dtype,
-                dtype=self.dtype,
+                q_dtype=self.dtype,
                window_left=self.sliding_window,
            )
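The model-level half of the fix: the prefill branch now forwards `self.kv_cache_dtype` and `self.dtype` as two separate arguments, matching what the decode branch already did. A hedged sketch of that wiring with a stand-in class, not TGI's `FlashCausalLM`:

import torch

class ModelSketch:
    def __init__(self, dtype: torch.dtype, kv_cache_dtype: torch.dtype):
        self.dtype = dtype                    # weights and queries
        self.kv_cache_dtype = kv_cache_dtype  # paged KV cache storage

m = ModelSketch(dtype=torch.bfloat16, kv_cache_dtype=torch.float8_e4m3fn)

# After this commit, prefill setup receives both attributes, so a quantized
# cache is no longer described to flashinfer with the query dtype.
print(m.dtype, m.kv_cache_dtype)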