Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
feat(server): batch tokenization for flash causal lm
commit 89c5621ecf
parent 895c5f1562
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,11 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate:]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)