From 89c5621ecfa0f26e9a5d2ce775a94a7a78b963f8 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 5 Jun 2023 14:15:01 +0200
Subject: [PATCH] feat(server): batch tokenization for flash causal lm

---
 .../models/flash_causal_lm.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5ff951b3..225133f5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,11 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate:]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
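
Note (not part of the patch): a minimal standalone sketch of the same pattern, assuming a Hugging Face AutoTokenizer and a hypothetical list of (inputs, truncate) pairs in place of pb.requests. One batched tokenizer call replaces the per-request calls, and each request then keeps only the last "truncate" token ids, mirroring the slice added in the second hunk.

    from transformers import AutoTokenizer

    # Hypothetical stand-ins for pb.requests: (inputs, truncate) pairs.
    requests = [
        ("Hello world", 8),
        ("A much longer prompt that may need to be cut down", 4),
    ]

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Single batched tokenizer call, truncated to the largest per-request limit.
    max_truncation = max(truncate for _, truncate in requests)
    batch_tokenized_inputs = tokenizer(
        [inputs for inputs, _ in requests],
        truncation=True,
        max_length=max_truncation,
    )["input_ids"]

    # Per-request truncation: keep only the last `truncate` token ids.
    for (_, truncate), tokenized_input in zip(requests, batch_tokenized_inputs):
        tokenized_input = tokenized_input[-truncate:]
        print(len(tokenized_input), tokenized_input)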