Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11)
flashinfer: fixup kv cache dtype
parent 41948e240f
commit 4e6990267e
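Summary (inferred from the diff below): the flashinfer wrappers previously took a single `dtype` that was used for both the query and the KV cache. This commit splits it into `kv_dtype` and `q_dtype` (prefill) and renames the query-side parameter to `q_dtype` (decode), so the flashinfer plans are built with the actual KV cache dtype, which can differ from the query dtype when the cache is quantized. It also deletes the `use_prefill_state` helper for the ragged (non-paged) KV prefill path, which this branch evidently no longer needs, and updates the call sites in `FlashCausalLM`.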
server/text_generation_server/layers/attention/flashinfer.py

@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    dtype: torch.dtype,
+    kv_dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -91,7 +92,8 @@ def use_prefill_with_paged_kv_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=dtype,
+            kv_data_type=kv_dtype,
+            q_data_type=q_dtype,
             page_size=page_size,
             window_left=-1 if window_left is None else window_left,
         )
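For illustration only (not part of the commit), a minimal sketch of why the plan needs both dtypes: with a quantized KV cache, query and cache tensors no longer share a dtype, so one `dtype` parameter can describe only one of them. The concrete dtypes and shapes below are assumptions for the example.

# Minimal sketch (assumed dtypes/shapes, not from the diff): with an FP8 KV
# cache, the query dtype and the cache dtype differ, so a single `dtype`
# cannot describe both sides of the attention plan.
import torch

q_dtype = torch.float16          # queries stay in the model's compute dtype
kv_dtype = torch.float8_e4m3fn   # the KV cache may be quantized (PyTorch >= 2.1)

q = torch.zeros(7, 32, 128, dtype=q_dtype)            # [tokens, heads, head_size]
kv_page = torch.zeros(2, 16, 8, 128, dtype=kv_dtype)  # [k/v, page, kv_heads, head_size]

# Before this fix, the prefill plan was built with q_data_type=dtype only,
# implicitly describing the cache as float16 here.
assert q.dtype != kv_page.dtype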
@@ -113,41 +115,6 @@ def create_prefill_state(
     )
 
 
-@contextmanager
-def use_prefill_state(
-    *,
-    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
-    cu_seqlens: torch.Tensor,
-    num_heads: int,
-    num_kv_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    window_left: int,
-):
-    """
-    Context manager to set the active flashinfer prefill state to the given
-    `state` and parameters. This state will be used by all calls to the
-    `attention` function while the context manager is active.
-    """
-
-    token = prefill_state.set(state)
-    try:
-        state.begin_forward(
-            qo_indptr=cu_seqlens,
-            kv_indptr=cu_seqlens,
-            num_qo_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_size,
-            q_data_type=dtype,
-            window_left=-1 if window_left is None else window_left,
-        )
-        yield
-    finally:
-        state.end_forward()
-        if token is not None:
-            prefill_state.reset(token)
-
-
 def create_decode_state(
     *,
     device: torch.device,
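The removed helper follows the standard `contextvars` token pattern: set the active state on entry, restore the previous value on exit. A stripped-down sketch of just that pattern, with illustrative stand-in names:

# Stripped-down sketch of the set/reset-token pattern used by the removed
# helper; `prefill_state` and the string payload are illustrative stand-ins.
from contextlib import contextmanager
from contextvars import ContextVar

prefill_state: ContextVar = ContextVar("prefill_state")

@contextmanager
def use_state(state):
    token = prefill_state.set(state)  # remember the previous value
    try:
        yield
    finally:
        prefill_state.reset(token)    # restore it, even on error

with use_state("wrapper"):
    assert prefill_state.get() == "wrapper"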
@@ -205,7 +172,7 @@ def use_decode_state(
     head_size: int,
     page_size: int,
     kv_cache_dtype: torch.dtype,
-    dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -242,7 +209,7 @@ def use_decode_state(
         head_dim=head_size,
         page_size=page_size,
         data_type=kv_cache_dtype,
-        q_data_type=dtype,
+        q_data_type=q_dtype,
         window_left=-1 if window_left is None else window_left,
     )
     yield
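On the decode side the wrapper already received the cache dtype as `data_type=kv_cache_dtype`; the fix renames the query-side parameter from `dtype` to `q_dtype` so the call site can no longer conflate the two. Illustrative shapes and dtypes for a paged decode step (layout assumed, not taken from the diff):

# Illustrative paged-KV layout (assumed, not from the diff): the cache lives
# in kv_cache_dtype while the single decode query stays in q_dtype.
import torch

num_pages, page_size, num_kv_heads, num_heads, head_size = 4, 16, 8, 32, 128

kv_cache = torch.zeros(
    num_pages, 2, page_size, num_kv_heads, head_size,
    dtype=torch.float8_e4m3fn,  # data_type=kv_cache_dtype
)
q = torch.zeros(1, num_heads, head_size, dtype=torch.float16)  # q_data_type=q_dtype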
server/text_generation_server/models/flash_causal_lm.py

@@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
-                dtype=self.dtype,
+                kv_dtype=self.kv_cache_dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
         else:
@@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
                 kv_cache_dtype=self.kv_cache_dtype,
-                dtype=self.dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
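Net effect: when `FlashCausalLM` is configured with a KV cache dtype that differs from the model dtype, both the prefill and decode flashinfer plans now receive the true cache dtype (`self.kv_cache_dtype`) and the true query dtype (`self.dtype`) instead of a single conflated value.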