Better prefix for edge cases.

This commit is contained in:
Nicolas Patry 2023-05-16 12:31:00 +02:00 committed by OlivierDehaene
parent 34e0a5b4a4
commit 8ddbdea45b
2 changed files with 8 additions and 5 deletions

View File

@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
).to(device)
for i, r in enumerate(pb.requests):
input_len = tokenized_inputs["input_ids"].shape[1]
offsets.append(input_len)
offsets.append(0)
token_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)

View File

@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
offsets: List[Optional[int]]
token_offsets: List[Optional[int]]
offsets: List[int]
token_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -91,8 +91,8 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
offsets.append(None)
token_offsets.append(None)
# offsets.append(None)
# token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -123,6 +123,9 @@ class Seq2SeqLMBatch(Batch):
.repeat(len(pb.requests))
.view(-1, 1)
)
for i, r in enumerate(pb.requests):
offsets.append(0)
token_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens