Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
commit 9a0262f38c (parent a64c2a6f89)

    tokenized inputs from pb
@@ -83,7 +83,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
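Note: the tokenizer these hunks forward is already in scope in each batch's from_pb; the context lines show it being passed to StoppingCriteria.from_pb, and the commit simply hands the same object to the chooser so it can tokenize logit-bias strings. A simplified sketch of the assumed surrounding shape (stub classes, not the verbatim source):

    class NextTokenChooser:
        # Stub standing in for the real TGI class.
        @classmethod
        def from_pb(cls, parameters, device, tokenizer):
            return cls()

    class StoppingCriteria:
        # Stub standing in for the real TGI class.
        @classmethod
        def from_pb(cls, stopping_parameters, tokenizer):
            return cls()

    def build_per_request_objects(pb, device, tokenizer):
        # One tokenizer serves both consumers: the stopping criteria
        # (as before this commit) and the chooser's logit-bias handling.
        choosers, criterias = [], []
        for r in pb.requests:
            choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
            criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
        return choosers, criterias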
@@ -303,7 +303,7 @@ class FlashCausalLMBatch(Batch):
             max_length = max(max_length, input_length + max_new_tokens)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -634,6 +634,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            # todo - determine how to obtain access to a tokenizer here
+            tokenizer=...
         )
 
         # Needed to avoid dropping blocks when the batches will go out of scope
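Note: the todo above sits in the batch-concatenation path, where only already-built batches are in scope and no tokenizer is available. One way out, sketched below as an assumption rather than anything the commit does, is for from_pb to keep the tokenizer on the chooser so concatenation can recover it from batches[0].next_token_chooser, exactly as dtype and device already are:

    class HeterogeneousNextTokenChooser:
        # Simplified stand-in: the real class also unpacks per-request
        # sampling parameters from the protobuf list.
        def __init__(self, dtype, device, tokenizer=None):
            self.dtype = dtype
            self.device = device
            self.tokenizer = tokenizer  # retained for later reuse

        @classmethod
        def from_pb(cls, pb, dtype, device, tokenizer):
            return cls(dtype=dtype, device=device, tokenizer=tokenizer)

        @classmethod
        def concatenate(cls, choosers):
            first = choosers[0]
            # dtype, device, and tokenizer all travel with the existing
            # choosers, so nothing new is threaded through this path.
            return cls(dtype=first.dtype, device=first.device, tokenizer=first.tokenizer)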
@@ -91,7 +91,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -92,7 +92,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -33,7 +33,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
-        logit_bias=None,
+        logit_bias={},
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
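Note: the new default logit_bias={} is a mutable default argument; Python evaluates it once at definition time and shares the same dict across every call that omits it, so any in-place mutation would leak between chooser instances. The usual idiom, shown as a general-Python sketch rather than the committed signature, keeps None and normalizes in the body:

    class Chooser:
        def __init__(self, logit_bias=None):
            # A {} default would be created once and shared by all calls;
            # normalizing None gives each instance its own dict.
            self.logit_bias = logit_bias if logit_bias is not None else {}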
@@ -85,6 +85,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -96,7 +97,7 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
-            logit_bias=pb.logit_bias,
+            logit_bias=dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb.logit_bias]),
         )
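Note: the added expression calls tokenizer.encode(...).input_ids, but a transformers tokenizer's encode() already returns a plain list of ids (it is tokenizer(...) that returns an object with .input_ids), and tuple(...) over the single id at index 0 would raise a TypeError. A hedged sketch of the apparent intent, assuming every token id produced by a biased string should receive that string's bias value (logit_bias_from_pb is a hypothetical helper, not part of the commit):

    from typing import Dict

    from transformers import PreTrainedTokenizerBase

    def logit_bias_from_pb(pb_logit_bias, tokenizer: PreTrainedTokenizerBase) -> Dict[int, float]:
        # encode() returns List[int]; a string may tokenize to several ids,
        # so each resulting id is mapped to the bias value.
        bias_map: Dict[int, float] = {}
        for bias in pb_logit_bias:
            for token_id in tokenizer.encode(bias.string, add_special_tokens=False):
                bias_map[token_id] = bias.bias
        return bias_map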
@@ -285,6 +286,7 @@ bias
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -297,7 +299,7 @@ bias
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
-            logit_bias={},
+            logit_bias=[dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb_.logit_bias]) for pb_ in pb],
         )
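Note: the heterogeneous variant builds one dict per request. At sampling time such a list would typically be applied by adding each mapping onto that request's row of the logits; a minimal sketch, assuming a [batch, vocab] logits tensor and per-request dicts like those built above (apply_logit_bias is hypothetical, not part of the commit):

    from typing import Dict, List

    import torch

    def apply_logit_bias(logits: torch.Tensor, per_request_bias: List[Dict[int, float]]) -> torch.Tensor:
        # Row i of the logits belongs to request i, so its dict is added
        # element-wise onto that row only.
        for row, bias_map in enumerate(per_request_bias):
            for token_id, bias in bias_map.items():
                logits[row, token_id] += bias
        return logits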