mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
flashinfer: fixup kv cache dtype
This commit is contained in:
parent 41948e240f
commit 4e6990267e
@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    dtype: torch.dtype,
+    kv_dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -91,7 +92,8 @@ def use_prefill_with_paged_kv_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=dtype,
+            kv_data_type=kv_dtype,
+            q_data_type=q_dtype,
             page_size=page_size,
             window_left=-1 if window_left is None else window_left,
         )
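
The two hunks above split the old single `dtype` parameter into `kv_dtype` and `q_dtype`, and forward both to flashinfer's planner instead of only `q_data_type=dtype`. A minimal sketch of why the split is needed, assuming an fp8-quantized KV cache next to fp16 queries (shapes and dtype values here are illustrative, not taken from this commit; fp8 dtypes need torch >= 2.1):

import torch

# Illustrative dtypes: queries follow the model dtype, while the paged KV
# cache may be stored quantized.
q_dtype = torch.float16
kv_dtype = torch.float8_e4m3fn

q = torch.randn(128, 32, 128).to(q_dtype)                  # ragged query tokens
kv_pages = torch.empty(64, 2, 16, 8, 128, dtype=kv_dtype)  # paged KV cache

# A single `dtype` argument cannot describe both tensors:
assert q.dtype != kv_pages.dtype

Before the change, flashinfer was told only the query dtype, so it had to assume the cache shared it; with a quantized KV cache that assumption breaks, hence the fixup.
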
@@ -113,41 +115,6 @@ def create_prefill_state(
     )
 
 
-@contextmanager
-def use_prefill_state(
-    *,
-    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
-    cu_seqlens: torch.Tensor,
-    num_heads: int,
-    num_kv_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    window_left: int,
-):
-    """
-    Context manager to set the active flashinfer prefill state to the given
-    `state` and parameters. This state will be used by all calls to the
-    `attention` function while the context manager is active.
-    """
-
-    token = prefill_state.set(state)
-    try:
-        state.begin_forward(
-            qo_indptr=cu_seqlens,
-            kv_indptr=cu_seqlens,
-            num_qo_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_size,
-            q_data_type=dtype,
-            window_left=-1 if window_left is None else window_left,
-        )
-        yield
-    finally:
-        state.end_forward()
-        if token is not None:
-            prefill_state.reset(token)
-
-
 def create_decode_state(
     *,
     device: torch.device,
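
This hunk deletes `use_prefill_state`, the context manager for the ragged (non-paged) prefill wrapper, which took only a single `dtype` and is presumably unused now that prefill goes through the paged-KV path. For reference, it relied on the same ContextVar publish/reset pattern as the surviving paged-KV variant; a stripped-down sketch of that pattern, with names simplified and the `begin_forward`/`end_forward` plumbing omitted:

from contextlib import contextmanager
from contextvars import ContextVar

# Simplified: the active wrapper is published through a ContextVar so the
# attention kernels can find it without threading it through every call.
prefill_state: ContextVar[object] = ContextVar("prefill_state")

@contextmanager
def use_state(state):
    token = prefill_state.set(state)
    try:
        yield  # attention calls inside this block see `state`
    finally:
        prefill_state.reset(token)
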
@@ -205,7 +172,7 @@ def use_decode_state(
     head_size: int,
     page_size: int,
     kv_cache_dtype: torch.dtype,
-    dtype: torch.dtype,
+    q_dtype: torch.dtype,
     window_left: int,
 ):
     """
@@ -242,7 +209,7 @@ def use_decode_state(
             head_dim=head_size,
             page_size=page_size,
             data_type=kv_cache_dtype,
-            q_data_type=dtype,
+            q_data_type=q_dtype,
             window_left=-1 if window_left is None else window_left,
         )
         yield
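
The decode path already took a separate `kv_cache_dtype` (forwarded as `data_type`), so these two hunks only rename `dtype` to `q_dtype` for symmetry with prefill. The distinction plausibly matters for planning because a page's size in bytes depends on the element width of the cache dtype; a quick check (shapes illustrative, fp8 dtypes need torch >= 2.1):

import torch

shape = (16, 2, 64, 8, 128)  # illustrative paged-KV layout
kv_fp16 = torch.empty(shape, dtype=torch.float16)
kv_fp8 = torch.empty(shape, dtype=torch.float8_e4m3fn)

# An fp8 page occupies half the bytes of an fp16 page of the same shape, so
# planning against the wrong dtype would miscompute sizes and strides.
assert kv_fp8.element_size() * 2 == kv_fp16.element_size()

The remaining hunks update the call sites in the `FlashCausalLM` model code to match the new parameter names.
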
@@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
-                dtype=self.dtype,
+                kv_dtype=self.kv_cache_dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
         else:
@@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
                 kv_cache_dtype=self.kv_cache_dtype,
-                dtype=self.dtype,
+                q_dtype=self.dtype,
                 window_left=self.sliding_window,
             )
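
At both call sites the KV-cache dtype now comes from `self.kv_cache_dtype` while the query dtype stays `self.dtype`. A hypothetical stand-in for the relationship between those two attributes (the real initialization lives in `FlashCausalLM.__init__` and is not part of this diff):

import torch
from typing import Optional

# Hypothetical reconstruction of the two model attributes used above.
class DtypeConfig:
    def __init__(self, dtype: torch.dtype, kv_cache_dtype: Optional[torch.dtype] = None):
        self.dtype = dtype  # activation/query dtype
        # Defaults to the model dtype; differs when the KV cache is quantized.
        self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype is not None else dtype

cfg = DtypeConfig(torch.bfloat16, kv_cache_dtype=torch.float8_e4m3fn)
assert cfg.kv_cache_dtype != cfg.dtype  # the case the old single `dtype` argument broke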