Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 14:22:08 +00:00)

feat: paged attention v2 (#1183)

commit 12590fdcce (parent 63fa534612)

@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3

 flash-attention-v2:
 	# Clone flash attention

@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc

 vllm:
 	# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git

 build-vllm: vllm
 	cd vllm && git fetch && git checkout $(vllm_commit)

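The vllm pin moves to the commit that ships the split-kernel PagedAttention V2 path; the new paged_attention.py added at the end of this diff imports paged_attention_v1, paged_attention_v2 and reshape_and_cache from that build. A small sanity check, offered as a sketch rather than as part of the commit, that the checkout produced by make build-vllm exposes those kernels:

# Sanity check (not part of this commit): confirm the vllm checkout built by
# `make build-vllm` exposes the kernels that the new paged_attention.py imports.
from vllm import attention_ops, cache_ops

for kernel in ("paged_attention_v1", "paged_attention_v2"):
    assert hasattr(attention_ops, kernel), f"vllm checkout too old: missing {kernel}"
assert hasattr(cache_ops, "reshape_and_cache"), "missing reshape_and_cache cache op"
print("vllm paged attention v1/v2 kernels are available")
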
@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

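Every modeling file below repeats the three-step pattern that this FlashLlamaAttention hunk introduces: write the step's K/V into the paged cache with paged_attention.reshape_and_cache, run flash_attn.attention for prefill (when cu_seqlen_prefill is set), and run paged_attention.attention for single-token decode. The sketch below condenses that control flow for reference; it is an illustration rather than code from the commit, and the flash_attn.attention arguments that the hunks truncate (the output tensor, cu_seqlen_prefill, max_s, softmax_scale) are assumptions based on the surrounding call sites.

# Condensed sketch of the dispatch pattern above. Names follow the hunks; the
# truncated middle arguments of flash_attn.attention are assumptions.
import torch

from text_generation_server.utils import paged_attention, flash_attn


def attend(query, kv, kv_cache, slots, cu_seqlen_prefill, block_tables,
           input_lengths, max_s, kv_head_mapping, softmax_scale):
    # 1. Append this step's keys/values into the paged KV cache.
    paged_attention.reshape_and_cache(
        kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
    )

    attn_output = torch.empty_like(query)

    if cu_seqlen_prefill is not None:
        # 2. Prefill: dense flash attention over the prompt tokens.
        flash_attn.attention(
            query,
            kv[:, 0],
            kv[:, 1],
            attn_output,        # assumed: result written in place
            cu_seqlen_prefill,  # assumed argument order past this point
            max_s,
            softmax_scale,
        )
    else:
        # 3. Decode: paged attention over the cached blocks; the signature is
        #    defined in the new utils/paged_attention.py at the end of this diff.
        paged_attention.attention(
            attn_output,
            query,
            kv_cache[0],
            kv_cache[1],
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            max_s,
        )
    return attn_output
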
@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )

@@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
@@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
@@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
             kv[:, :, 1].contiguous(),
             kv_cache[0],
@@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
-from safetensors import SafetensorError


 def load_multi_mqa(
@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )

@@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

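Across the Llama, Mistral, NeoX, RW and SantaCoder hunks, a single decode path covers multi-head, grouped and multi-query attention because paged_attention.attention takes a kv_head_mapping tensor that tells the kernel which KV head each query head reads from. The mappings are built in model code this diff does not show, so the constructions below are assumptions, included only to illustrate the three layouts:

# Illustration only: plausible kv_head_mapping constructions for the attention
# layouts touched above. The real mappings are built in the models' own setup
# code, which this diff does not include; treat these as assumptions.
import torch

num_heads = 32

# Multi-head attention (e.g. FlashNeoxAttention): each query head has its own KV head.
mha_mapping = torch.arange(0, num_heads, dtype=torch.int32)

# Grouped-query attention (e.g. MistralAttention with 8 KV heads): blocks of query
# heads share one KV head.
num_kv_heads = 8
gqa_mapping = torch.arange(0, num_kv_heads, dtype=torch.int32).repeat_interleave(
    num_heads // num_kv_heads
)

# Multi-query attention (FlashMQAttention): one KV head serves every query head.
mqa_mapping = torch.zeros(num_heads, dtype=torch.int32)
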
server/text_generation_server/utils/paged_attention.py (new file, 100 lines):

import torch

# vllm imports
from vllm import cache_ops
from vllm import attention_ops

_PARTITION_SIZE = 512


def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
                      slots: torch.Tensor):
    cache_ops.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    if use_v1:
        attention_ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
        attention_ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )
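
The V1/V2 switch in the new file is plain integer arithmetic on max_s, num_seqs and num_heads, so it is easy to predict which kernel a given batch will hit. A standalone restatement of the heuristic with a few worked values (illustration only; the constant mirrors _PARTITION_SIZE above):

# Restatement of the use_v1 heuristic from paged_attention.py, for illustration.
_PARTITION_SIZE = 512


def uses_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # V1 when the context fits in one partition or there is already enough
    # parallel work (many sequences * heads); otherwise V2 with its extra
    # reduction over partitions.
    return max_num_partitions == 1 or num_seqs * num_heads > 512


print(uses_v1(max_s=512, num_seqs=1, num_heads=32))    # True: single partition -> V1
print(uses_v1(max_s=4096, num_seqs=32, num_heads=32))  # True: 1024 > 512 -> V1
print(uses_v1(max_s=4096, num_seqs=4, num_heads=32))   # False: 8 partitions, 128 <= 512 -> V2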