Fixing initialization of token, token_offset.

Nicolas Patry 2023-05-16 12:14:36 +02:00 committed by OlivierDehaene
parent 1aa31bb5cc
commit 34e0a5b4a4
2 changed files with 21 additions and 7 deletions


@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -81,8 +81,8 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -102,6 +102,10 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for i, r in enumerate(pb.requests):
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            offsets.append(input_len)
+            token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
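
Before this change, offsets and token_offsets were appended as None while iterating over the raw requests and only filled in later, which the tightened List[int] annotations above no longer allow. The last hunk instead populates both lists right after tokenization, using the padded input length. A minimal runnable sketch of that new ordering, assuming a standard transformers tokenizer with left padding; the model name, prompts, and max_length are illustrative, not taken from the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

requests = ["Hello there", "A somewhat longer prompt for this batch"]
offsets, token_offsets = [], []

tokenized_inputs = tokenizer(
    requests,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,
)

# Initialization now happens *after* tokenization: every request starts
# reading from the padded input length instead of from None.
for _ in requests:
    input_len = tokenized_inputs["input_ids"].shape[1]
    offsets.append(input_len)
    token_offsets.append(input_len)

assert all(isinstance(o, int) for o in offsets)  # no Optional[int] left

Because the batch is padded to a common length, tokenized_inputs["input_ids"].shape[1] is identical for every row, which is why the loop can append the same value for each request.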


@@ -67,10 +67,20 @@ class Model(ABC):
         if read_offset is None:
             read_offset = 0
 
-        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
-        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=False
+        )
 
-        if len(new_text) > len(prefix_text) and "�" not in new_text:
+        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            # utf-8 char at the end means it's a potential unfinished byte sequence
+            # from byte fallback tokenization.
+            # If it's in the middle, it's probably a real invalid id generated
+            # by the model
             new_text = new_text[len(prefix_text) :]
             return new_text, read_offset, len(all_input_ids)
         else:
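
The guard change is the user-visible fix for streaming: the old condition "�" not in new_text suppressed output whenever a replacement character appeared anywhere in the decoded text, including ones the model legitimately generated mid-string, while new_text.endswith("�") only holds text back while the final token leaves a multi-byte UTF-8 sequence unfinished. Below is a standalone sketch of the same incremental-decode idea, written as a free function with an explicit driver loop; the function shape, model name, and prompt are illustrative, not the commit's code:

from typing import List, Tuple

from transformers import AutoTokenizer

def decode_token(
    tokenizer,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the previously read window and the full window; the prefix
    # decode exists only to neutralize the tokenizer's cleanup heuristics
    # (e.g. whether to insert a leading space).
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Emit only the freshly completed text and advance both offsets.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Nothing new, or the last token left an unfinished byte sequence:
    # emit nothing and keep the offsets where they were.
    return "", prefix_offset, read_offset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("Hello world, how are you?")["input_ids"]

prefix_offset, read_offset = 0, 0
pieces = []
for i in range(1, len(ids) + 1):  # simulate ids arriving one at a time
    text, prefix_offset, read_offset = decode_token(
        tokenizer, ids[:i], prefix_offset, read_offset
    )
    pieces.append(text)
print("".join(pieces))  # reassembles the original string

The driver loop shows why the endswith check matters: a token that covers only part of a multi-byte character decodes to a trailing "\ufffd", so its text is withheld until the next token completes the sequence, at which point the whole character is emitted at once.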