Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
Commit 8ddbdea45b (parent 34e0a5b4a4): Better prefix for edge cases.

The commit changes how each request's decode offsets are initialized in the causal LM and seq2seq batch classes.
In CausalLMBatch.from_pb, the per-request decode offset now starts at 0 (the beginning of the prompt) instead of at the prompt length, so the detokenizer keeps the whole prompt as prefix context:

```diff
@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
         ).to(device)
         for i, r in enumerate(pb.requests):
             input_len = tokenized_inputs["input_ids"].shape[1]
-            offsets.append(input_len)
+            offsets.append(0)
             token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
```
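The diff doesn't show how these offsets are consumed, but they are the bookkeeping for incremental detokenization during streamed generation. Below is a minimal sketch of that pattern with a hypothetical `decode_incrementally` helper (not the repo's actual function): text is always re-decoded from `offset`, and only the part beyond what was already read gets emitted. Starting `offset` at 0 hands the tokenizer the full prompt as left context, which sidesteps edge cases where decoding a suffix in isolation produces different spacing or merges.

```python
from typing import List, Tuple

def decode_incrementally(
    tokenizer, all_input_ids: List[int], offset: int, token_offset: int
) -> Tuple[str, int, int]:
    # Text covering the ids that were already read, and text covering
    # everything so far -- both decoded from `offset` so the tokenizer
    # sees the same left context in both calls.
    read_text = tokenizer.decode(all_input_ids[offset:token_offset])
    full_text = tokenizer.decode(all_input_ids[offset:])
    if len(full_text) > len(read_text):
        # Stable new text appeared: emit the delta, advance the read marker.
        return full_text[len(read_text):], offset, len(all_input_ids)
    # Nothing printable yet (e.g. an incomplete multi-byte sequence).
    return "", offset, token_offset
```

With the hunk above, a causal LM request starts at offset=0 and token_offset=input_len, so the first emitted chunk is decode(prompt + first token) minus decode(prompt). A production implementation would likely also slide `offset` forward to bound the re-decoded window; the sketch keeps it fixed for clarity.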
In Seq2SeqLMBatch, the offset fields are tightened from List[Optional[int]] to List[int], since they are now always filled with concrete values:

```diff
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
```
from_pb accordingly stops seeding the lists with None in the per-request loop (the old appends are left commented out):

```diff
@@ -91,8 +91,8 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
```
Instead, once decoder_input_ids has been built (one decoder start token per request), each request gets an offset of 0 and a token offset of 1:

```diff
@@ -123,6 +123,9 @@ class Seq2SeqLMBatch(Batch):
             .repeat(len(pb.requests))
             .view(-1, 1)
         )
+        for i, r in enumerate(pb.requests):
+            offsets.append(0)
+            token_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
```
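For seq2seq batches the decoder begins with a single start token, so offset 0 / token offset 1 read as: decode from the very first decoder token, but count the start token as already read. A small usage sketch chained on the hypothetical `decode_incrementally` helper above; the T5 tokenizer and the "generated" id are purely illustrative, not part of the commit:

```python
from transformers import AutoTokenizer

# Assumes decode_incrementally from the sketch above is in scope.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# One decoder start token per request (T5 reuses <pad> as its start token),
# mirroring offsets.append(0) / token_offsets.append(1) in the hunk above.
decoder_ids = [tokenizer.pad_token_id]
offset, token_offset = 0, 1

# Pretend the model just produced its first token.
decoder_ids += tokenizer("Hello", add_special_tokens=False).input_ids[:1]

text, offset, token_offset = decode_incrementally(
    tokenizer, decoder_ids, offset, token_offset
)
print(repr(text))  # only the new text; the start token never reaches the stream
```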