tokenized inputs from pb

marcusdunn 2023-08-10 10:16:47 -07:00
parent a64c2a6f89
commit 9a0262f38c
5 changed files with 11 additions and 7 deletions

server/text_generation_server/models/causal_lm.py

@@ -83,7 +83,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

server/text_generation_server/models/flash_causal_lm.py

@@ -303,7 +303,7 @@ class FlashCausalLMBatch(Batch):
             max_length = max(max_length, input_length + max_new_tokens)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -634,6 +634,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            # todo - determine how to obtain access to a tokenizer here
+            tokenizer=...
         )
 
         # Needed to avoid dropping blocks when the batches will go out of scope
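The `tokenizer=...` placeholder above marks an open question: `concatenate` rebuilds a chooser from already-constructed batches and has no tokenizer in scope. One way to resolve the todo, sketched below as an assumption rather than anything in this commit, is to keep the tokenizer on the chooser at construction time so merging code can forward the first batch's reference:

# Hypothetical sketch, not part of this commit: retain the tokenizer that
# from_pb received so concatenate() can reuse it instead of `...`.
from dataclasses import dataclass
from typing import List, Optional

from transformers import PreTrainedTokenizerBase


@dataclass
class ChooserState:
    # Tokenizer used to encode logit-bias strings; None when unused.
    tokenizer: Optional[PreTrainedTokenizerBase] = None


def concatenate(states: List[ChooserState]) -> ChooserState:
    # Every batch in a model was built with the same tokenizer, so
    # reusing the first batch's reference is safe.
    return ChooserState(tokenizer=states[0].tokenizer)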

server/text_generation_server/models/galactica.py

@@ -91,7 +91,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

server/text_generation_server/models/seq2seq_lm.py

@@ -92,7 +92,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

server/text_generation_server/utils/tokens.py

@@ -33,7 +33,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
-        logit_bias=None,
+        logit_bias={},
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -85,6 +85,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -96,7 +97,7 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
-            logit_bias=pb.logit_bias,
+            logit_bias={tuple(tokenizer.encode(bias.string, add_special_tokens=False)): bias.bias for bias in pb.logit_bias},
         )
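The inline conversion above is dense, so here is the same step in isolation. The helper name is hypothetical, not part of the commit: each protobuf entry pairs a string with a bias, and the resulting dict is keyed on the tuple of token ids the tokenizer produces for that string, with add_special_tokens=False so only the literal text is encoded:

# Hypothetical helper, not part of this commit.
from typing import Dict, Iterable, Tuple

from transformers import PreTrainedTokenizerBase


def encode_logit_bias(
    entries: Iterable,  # pb entries carrying .string and .bias fields
    tokenizer: PreTrainedTokenizerBase,
) -> Dict[Tuple[int, ...], float]:
    # Key each bias on the exact token-id sequence of its string so it
    # can later be matched against generated ids.
    return {
        tuple(tokenizer.encode(entry.string, add_special_tokens=False)): entry.bias
        for entry in entries
    }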
@@ -285,6 +286,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -297,7 +299,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
-            logit_bias={},
+            logit_bias=[{tuple(tokenizer.encode(bias.string, add_special_tokens=False)): bias.bias for bias in pb_.logit_bias} for pb_ in pb],
         )
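For reference, a quick usage sketch of the encoding step with a concrete tokenizer (the model name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A multi-token string becomes a tuple of ids keying its bias value.
key = tuple(tokenizer.encode("Hello world", add_special_tokens=False))
logit_bias = {key: 2.5}
print(logit_bias)  # e.g. {(15496, 995): 2.5}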