feat(server): batch tokenization for flash causal lm

OlivierDehaene 2023-06-05 14:15:01 +02:00
parent 895c5f1562
commit 89c5621ecf


@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,11 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate:]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
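
Below is a minimal, self-contained sketch of the pattern this diff introduces: tokenize every request input in a single tokenizer call bounded by the largest per-request truncate value, then trim each sequence to its own limit with a slice. The Request namedtuple, the sample inputs, and the "gpt2" tokenizer are hypothetical stand-ins for the pb.requests protobuf objects and the model tokenizer, so this illustrates the approach rather than reproducing the repository's code.

from collections import namedtuple

from transformers import AutoTokenizer

# Hypothetical stand-in for the pb.requests protobuf messages.
Request = namedtuple("Request", ["id", "inputs", "truncate"])

requests = [
    Request(id=0, inputs="The quick brown fox jumps over the lazy dog", truncate=8),
    Request(id=1, inputs="Deep learning is", truncate=4),
]

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model for the sketch

# One tokenizer call for the whole batch, truncated to the largest
# per-request limit instead of one call per request.
batch_inputs = [r.inputs for r in requests]
max_truncation = max(r.truncate for r in requests)
batch_tokenized_inputs = tokenizer(
    batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]

# Each request then keeps only its last `truncate` tokens, mirroring
# `tokenized_input = tokenized_input[-r.truncate:]` in the diff above.
for r, tokenized_input in zip(requests, batch_tokenized_inputs):
    tokenized_input = tokenized_input[-r.truncate:]
    print(r.id, len(tokenized_input))

Tokenizing once amortizes the per-call overhead of the tokenizer across the batch, and the trailing slice stays cheap because truncation to max_truncation has already bounded every sequence.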