diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 870d261f..ab5bd6a6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -81,8 +81,8 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -102,6 +102,10 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for i, r in enumerate(pb.requests):
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            offsets.append(input_len)
+            token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 657e4821..3fa50678 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -67,10 +67,20 @@ class Model(ABC):
         if read_offset is None:
             read_offset = 0
 
-        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
-        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=False
+        )
 
-        if len(new_text) > len(prefix_text) and "�" not in new_text:
+        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            # utf-8 char at the end means it's a potential unfinished byte sequence
+            # from byte fallback tokenization.
+            # If it's in the middle, it's probably a real invalid id generated
+            # by the model
             new_text = new_text[len(prefix_text) :]
             return new_text, read_offset, len(all_input_ids)
         else:
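
Note (not part of the patch): a minimal, self-contained sketch of why checking only a trailing "�" is enough to detect an unfinished byte-fallback sequence. Truncating a multi-byte UTF-8 character mid-sequence decodes to the replacement character at the end of the string, which is exactly what the patched `new_text.endswith("�")` condition holds back until more ids arrive; a "�" in the middle is, as the new comment says, more likely a genuinely invalid id from the model. The snippet below only mimics that byte-level situation and does not use a real tokenizer.

    # Illustration only: an incomplete UTF-8 sequence decodes to a trailing
    # replacement character, a complete one does not.
    emoji_bytes = "🤗".encode("utf-8")  # 4-byte UTF-8 sequence

    partial = emoji_bytes[:2].decode("utf-8", errors="replace")
    complete = emoji_bytes.decode("utf-8", errors="replace")

    print(partial.endswith("�"))   # True  -> unfinished sequence, keep buffering
    print(complete.endswith("�"))  # False -> safe to emit to the stream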