mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Update?
This commit is contained in:
parent
fcbc6876c0
commit
212a59544b
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
is_sm75 = major == 7 and minor == 5
|
is_sm75 = major == 7 and minor == 5
|
||||||
@ -62,7 +62,8 @@ def paged_attention(
|
|||||||
#
|
#
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
block_size = value_cache.shape[3]
|
# block_size = value_cache.shape[3]
|
||||||
|
block_size = BLOCK_SIZE
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = cu_seqlen_k
|
input_lengths = cu_seqlen_k
|
||||||
@ -87,7 +88,7 @@ def paged_attention(
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
None,
|
None,
|
||||||
cu_seqlen_k,
|
cu_seqlen_q,
|
||||||
cu_seqlen_k,
|
cu_seqlen_k,
|
||||||
None,
|
None,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
Loading…
Reference in New Issue
Block a user