From 9a0262f38ccf634d212600e78c83202fccb297cf Mon Sep 17 00:00:00 2001
From: marcusdunn
Date: Thu, 10 Aug 2023 10:16:47 -0700
Subject: [PATCH] tokenized inputs from pb

---
 server/text_generation_server/models/causal_lm.py       | 2 +-
 server/text_generation_server/models/flash_causal_lm.py | 4 +++-
 server/text_generation_server/models/galactica.py       | 2 +-
 server/text_generation_server/models/seq2seq_lm.py      | 2 +-
 server/text_generation_server/utils/tokens.py           | 8 +++++---
 5 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cbdf4808..2eff73d4 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -83,7 +83,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7de51358..860c9bc5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -303,7 +303,7 @@ class FlashCausalLMBatch(Batch):
             max_length = max(max_length, input_length + max_new_tokens)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -634,6 +634,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            # todo - determine how to obtain access to a tokenizer here
+            tokenizer=...
         )

         # Needed to avoid dropping blocks when the batches will go out of scope
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d4211734..3658693b 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -91,7 +91,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 9e5c21d1..9cffad48 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -92,7 +92,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index f99405ad..1368f47d 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -33,7 +33,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
-        logit_bias=None,
+        logit_bias={},
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -85,6 +85,7 @@
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -96,7 +97,7 @@
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
-            logit_bias=pb.logit_bias,
+            logit_bias=dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb.logit_bias]),
         )
@@ -285,6 +286,7 @@ bias
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -297,7 +299,7 @@ bias
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
-            logit_bias={},
+            logit_bias=[dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb_.logit_bias]) for pb_ in pb],
         )
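Note on the logit_bias decoding added in NextTokenChooser.from_pb: PreTrainedTokenizerBase.encode returns a plain list of token ids, so chaining .input_ids onto its result (and then tuple(...[0]) on a single int) is unlikely to run as written. A minimal sketch of one way the intended string-to-bias mapping could be built, assuming each bias.string should key on the tuple of its encoded token ids; the pb field names (bias.string, bias.bias) are taken from the patch, while _logit_bias_from_pb is a hypothetical helper, not code from the repository:

# Sketch only, not the patch's implementation.
from typing import Dict, Tuple

from transformers import PreTrainedTokenizerBase


def _logit_bias_from_pb(pb_logit_bias, tokenizer: PreTrainedTokenizerBase) -> Dict[Tuple[int, ...], float]:
    biases: Dict[Tuple[int, ...], float] = {}
    for bias in pb_logit_bias:
        # encode() returns List[int]; key on the full token-id tuple so
        # multi-token strings remain representable.
        token_ids = tokenizer.encode(bias.string, add_special_tokens=False)
        biases[tuple(token_ids)] = bias.bias
    return biases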
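Note on the open "todo - determine how to obtain access to a tokenizer here" in FlashCausalLMBatch.concatenate: concatenate only sees already-built batches, so one option is for HeterogeneousNextTokenChooser to keep a reference to the tokenizer it was constructed with and for concatenate to reuse it from the first batch. A fragment as it might look inside concatenate under that assumption; the stored tokenizer attribute is assumed and is not something this patch adds:

        # Sketch only: assumes the chooser stores the tokenizer it was built with.
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,  # assumed attribute
        )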