mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Better prefix for edge cases.
This commit is contained in:
parent
34e0a5b4a4
commit
8ddbdea45b
@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
|
||||
).to(device)
|
||||
for i, r in enumerate(pb.requests):
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
offsets.append(input_len)
|
||||
offsets.append(0)
|
||||
token_offsets.append(input_len)
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
|
@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
decoder_input_lengths: List[int]
|
||||
offsets: List[Optional[int]]
|
||||
token_offsets: List[Optional[int]]
|
||||
offsets: List[int]
|
||||
token_offsets: List[int]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
@ -91,8 +91,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
inputs.append(r.inputs)
|
||||
requests_idx_mapping[r.id] = i
|
||||
decoder_input_lengths.append(1)
|
||||
offsets.append(None)
|
||||
token_offsets.append(None)
|
||||
# offsets.append(None)
|
||||
# token_offsets.append(None)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
@ -123,6 +123,9 @@ class Seq2SeqLMBatch(Batch):
|
||||
.repeat(len(pb.requests))
|
||||
.view(-1, 1)
|
||||
)
|
||||
for i, r in enumerate(pb.requests):
|
||||
offsets.append(0)
|
||||
token_offsets.append(1)
|
||||
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
Loading…
Reference in New Issue
Block a user