diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ab5bd6a6..7dc3e0aa 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
         ).to(device)
         for i, r in enumerate(pb.requests):
             input_len = tokenized_inputs["input_ids"].shape[1]
-            offsets.append(input_len)
+            offsets.append(0)
             token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 98be4c71..42201e9f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -91,8 +91,8 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -123,6 +123,9 @@ class Seq2SeqLMBatch(Batch):
             .repeat(len(pb.requests))
             .view(-1, 1)
         )
+        for i, r in enumerate(pb.requests):
+            offsets.append(0)
+            token_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
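
Note: the patch moves both batch types from lazily-initialized Optional offsets to concrete integers. The character offset starts at 0 and the token offset at the length of the already-materialized prefix (the tokenized prompt for the causal LM, the single decoder start token for the seq2seq model). Below is a minimal sketch of the incremental-detokenization pattern these fields feed; the function name and signature are assumptions made for illustration, not the repository's actual decode_token API.

    from typing import Callable, List, Tuple


    def incremental_decode(
        decode: Callable[[List[int]], str],
        all_input_ids: List[int],
        offset: int,
        token_offset: int,
    ) -> Tuple[str, int, int]:
        # Hypothetical helper: decode the window of tokens starting at
        # `token_offset` and emit only the characters past `offset`,
        # i.e. the text that has not been streamed to the client yet.
        text = decode(all_input_ids[token_offset:])
        new_text = text[offset:]
        # The next call skips everything emitted so far.
        return new_text, len(text), token_offset

Under that reading, offsets.append(0) / token_offsets.append(input_len) in causal_lm.py means "no characters emitted yet, start decoding after the prompt", and offsets.append(0) / token_offsets.append(1) in seq2seq_lm.py does the same relative to the single decoder start token.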