diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5ff951b3..225133f5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,11 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate:]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
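
The sketch below illustrates the pattern the diff introduces: tokenize the whole batch in a single tokenizer call using the largest per-request truncation limit, then enforce each request's own limit by slicing off all but its last `truncate` tokens. It is a minimal standalone example, not the server code: the `gpt2` model name and the `FakeRequest` container are assumptions standing in for the gRPC `pb.requests` objects.

```python
from dataclasses import dataclass

from transformers import AutoTokenizer


@dataclass
class FakeRequest:
    # Stand-in for the protobuf request: only the fields the diff touches.
    inputs: str
    truncate: int


# Assumed model for the sketch; the server loads the tokenizer of the deployed model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

requests = [
    FakeRequest(inputs="Hello, how are you today?", truncate=4),
    FakeRequest(inputs="The quick brown fox jumps over the lazy dog", truncate=8),
]

# One batched tokenizer call, truncated to the largest limit in the batch,
# instead of one call per request.
max_truncation = max(r.truncate for r in requests)
batch_tokenized_inputs = tokenizer(
    [r.inputs for r in requests], truncation=True, max_length=max_truncation
)["input_ids"]

# Each request then keeps only its last `truncate` tokens, mirroring
# `tokenized_input = tokenized_input[-r.truncate:]` in the diff.
for r, tokenized_input in zip(requests, batch_tokenized_inputs):
    tokenized_input = tokenized_input[-r.truncate:]
    print(f"truncate={r.truncate} -> {len(tokenized_input)} tokens: {tokenized_input}")
```

The per-request slice is what keeps the behavior equivalent to the old per-request tokenizer calls: the batched call can only apply a single `max_length`, so requests with a smaller `truncate` still need their own trim afterwards.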