mirror of https://github.com/huggingface/text-generation-inference.git
Fixing initialization of token, token_offset.
This commit is contained in:
parent
1aa31bb5cc
commit
34e0a5b4a4
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    offsets: List[int]
+    token_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
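The Optional wrapper can be dropped above because, after this commit, both lists are filled with real integers as soon as the batch is built (see the hunk at line 102 below) instead of starting as None placeholders. A minimal sketch of what that buys downstream consumers (hypothetical helper names, not from the repo):

    from typing import List, Optional

    # Before: entries start as None, so every consumer needs a guard.
    def read_offset_before(token_offsets: List[Optional[int]], i: int) -> int:
        off = token_offsets[i]
        return 0 if off is None else off

    # After: entries are initialized to the tokenized input length up front,
    # so the List[int] annotation is accurate and the guard disappears.
    def read_offset_after(token_offsets: List[int], i: int) -> int:
        return token_offsets[i]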
@@ -81,8 +81,8 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
+            # offsets.append(None)
+            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
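The two placeholder appends are disabled here because the initialization moves below the tokenizer call (next hunk), where the real input length is known; keeping both appends would give every request two entries. A simplified, hypothetical sketch of the double-append this avoids:

    offsets = []
    requests = ["hello", "world"]

    for _ in requests:
        offsets.append(None)  # old placeholder append (now commented out)
    for _ in requests:
        offsets.append(5)     # new length-based append
    # Both loops together leave 4 entries for 2 requests: wrong.
    assert len(offsets) == 2 * len(requests)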
@@ -102,6 +102,10 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for i, r in enumerate(pb.requests):
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            offsets.append(input_len)
+            token_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
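One detail worth noting in the new loop: tokenized_inputs["input_ids"].shape[1] is the width of the padded batch tensor, so every request receives the same (maximum) length, whereas attention_mask.sum(1) on the following line recovers each request's true unpadded length. A standalone sketch of the two quantities using transformers (the gpt2 checkpoint is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

    batch = tokenizer(
        ["short", "a much longer input sentence"],
        return_tensors="pt",
        padding=True,
    )

    # Padded width: identical for every row of the batch.
    print(batch["input_ids"].shape[1])

    # Per-request lengths: counts only real (non-pad) tokens.
    print(batch["attention_mask"].sum(1))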
@@ -67,10 +67,20 @@ class Model(ABC):
         if read_offset is None:
             read_offset = 0
 
-        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
-        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=False
+        )
 
-        if len(new_text) > len(prefix_text) and "�" not in new_text:
+        # utf-8 char at the end means it's a potential unfinished byte sequence
+        # from byte fallback tokenization.
+        # If it's in the middle, it's probably a real invalid id generated
+        # by the model
+        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
             new_text = new_text[len(prefix_text) :]
             return new_text, read_offset, len(all_input_ids)
         else:
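For context on the hunk above: decoding only the newly generated ids in isolation can produce wrong spacing, because tokenizers decide whether to emit a leading space based on the surrounding ids. Decoding a window that also covers the already-read prefix and then stripping the prefix text sidesteps that, and a trailing � signals a byte-fallback sequence that is still incomplete, so nothing should be emitted yet. A self-contained sketch of the same technique outside the Model class (the tokenizer choice and the else-branch return are assumptions, since the diff cuts off after else:):

    from transformers import AutoTokenizer

    def decode_token(tokenizer, all_input_ids, prefix_offset=0, read_offset=0):
        # Decode a window that includes already-read ids so the tokenizer's
        # spacing cleanup sees the same context both times.
        prefix_text = tokenizer.decode(
            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
        )
        new_text = tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=False
        )
        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # The remainder after the prefix is genuinely new, finished text.
            return new_text[len(prefix_text):], read_offset, len(all_input_ids)
        # Assumed fallback: emit nothing and keep the offsets so the next
        # token can complete the unfinished byte sequence.
        return "", prefix_offset, read_offset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer("Hello world, streaming!").input_ids
    text, prefix_offset, read_offset = decode_token(tokenizer, ids, 0, 2)
    print(repr(text))  # new text after the first two ids, spacing preserved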