mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Fixing initialization of token, token_offset.
This commit is contained in:
parent
1aa31bb5cc
commit
34e0a5b4a4
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -81,8 +81,8 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -102,6 +102,10 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for i, r in enumerate(pb.requests):
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            offsets.append(input_len)
+            token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
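The three hunks above are one change: offsets and token_offsets lose their Optional because they are no longer seeded with None per request, but filled in only after the whole batch has been tokenized, when the padded input length is known. A minimal standalone sketch of that ordering, using a hypothetical requests list and a plain Hugging Face tokenizer rather than the real CausalLMBatch.from_pb signature:

from typing import List

from transformers import AutoTokenizer

# Assumption: any tokenizer with padding works; gpt2 needs a pad token set.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

requests = ["Hello world", "A much longer prompt than the first one"]

offsets: List[int] = []
token_offsets: List[int] = []

# Tokenize the whole batch first; padding makes every row the same length.
tokenized_inputs = tokenizer(requests, return_tensors="pt", padding=True)

# Only now is the (padded) input length known, so the offsets can start
# from a real int instead of None.
input_len = tokenized_inputs["input_ids"].shape[1]
for _ in requests:
    offsets.append(input_len)
    token_offsets.append(input_len)

# The per-request attention_mask sums still carry the true lengths.
input_lengths = tokenized_inputs["attention_mask"].sum(1)
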
@@ -67,10 +67,20 @@ class Model(ABC):
         if read_offset is None:
             read_offset = 0
 
-        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
-        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=False
+        )
 
-        if len(new_text) > len(prefix_text) and "�" not in new_text:
+        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            # utf-8 char at the end means it's a potential unfinished byte sequence
+            # from byte fallback tokenization.
+            # If it's in the middle, it's probably a real invalid id generated
+            # by the model
             new_text = new_text[len(prefix_text) :]
             return new_text, read_offset, len(all_input_ids)
         else:
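This hunk changes the streaming decode in two ways: the prefix/new text pair is decoded with skip_special_tokens=False, and output is held back only when the decoded text ends with the replacement character "�" (a potentially unfinished byte sequence from byte-fallback tokenization) rather than whenever "�" appears anywhere. A self-contained sketch of the same check, written as a free function with a hypothetical tokenizer instead of the Model method; a sketch, not the repository's exact code:

from typing import List, Tuple

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer


def decode_token(
    all_input_ids: List[int], prefix_offset: int, read_offset: int
) -> Tuple[str, int, int]:
    # Decoding a prefix window defeats cleanup heuristics in decode() that
    # add or drop a space depending on the surrounding ids.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )

    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # A trailing "�" means a potentially unfinished byte sequence:
        # wait for more ids before emitting anything.
        return new_text[len(prefix_text) :], read_offset, len(all_input_ids)
    return "", prefix_offset, read_offset


# Usage: feed ids one by one and print only the finished text chunks.
ids = tokenizer("Hello world, how are you?")["input_ids"]
prefix_offset, read_offset = 0, 0
for i in range(1, len(ids) + 1):
    text, prefix_offset, read_offset = decode_token(ids[:i], prefix_offset, read_offset)
    if text:
        print(text, end="", flush=True)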