From 642969522821a2ccf726198d38f3cbd272d4f1ed Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Tue, 1 Aug 2023 13:55:38 +0000
Subject: [PATCH] Skip top-n tokens in prefill

---
 .../models/causal_lm.py       | 25 +++++++++++--------
 .../models/flash_causal_lm.py | 22 +++++++++-------
 .../models/seq2seq_lm.py      | 25 +++++++++++--------
 3 files changed, 43 insertions(+), 29 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 929361e6..807c39de 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -606,15 +606,6 @@ class CausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(-1).tolist(),
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids.view(1, -1), logits[-1:, :]
@@ -661,7 +652,8 @@ class CausalLM(Model):
                 generated_text = None
 
             # Prefill
-            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+            prefill = stopping_criteria.current_tokens == 1
+            if prefill and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
@@ -680,6 +672,19 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
+            # Todo: Make optional for prefill
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_input_ids[:-1].view(-1).tolist(),
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 672688c8..3fea3e0e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -982,15 +982,6 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids,
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
 
@@ -1048,6 +1039,19 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
 
+            # Todo: Make optional for prefill
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_input_ids[:-1],
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f5703ceb..096ab82f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -675,15 +675,6 @@ class Seq2SeqLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(-1).tolist(),
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
@@ -731,7 +722,8 @@ class Seq2SeqLM(Model):
                 generated_text = None
 
             # Prefill
-            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+            prefill = stopping_criteria.current_tokens == 1
+            if prefill and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
@@ -740,6 +732,19 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
 
+            # Todo: Make optional for prefill. How to implement in API?
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
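
Note for reviewers (not part of the patch): the sketch below restates the guard this diff wraps around `decode_top_tokens`, so the intended behaviour can be checked in isolation. It is a minimal sketch, not the server code; `decode_top_tokens_stub` and `maybe_decode_top_tokens` are names invented here for illustration, and the stub only pairs candidate ids with their logprobs instead of detokenizing them. The two points it demonstrates are the ones the diff introduces: no top-n decoding on the prefill step, and on decode steps the freshly appended token is dropped from the context via `all_input_ids[:-1]`.

# Illustrative sketch only; stand-in for the decode_top_tokens call sites above.
from typing import List, Optional, Tuple


def decode_top_tokens_stub(
    input_ids: List[int],
    top_n_tokens: int,
    top_token_ids: List[int],
    top_token_logprobs: List[float],
) -> List[Tuple[int, float]]:
    # Stand-in for the models' decode_top_tokens: pair each candidate id with
    # its logprob instead of detokenizing it against `input_ids`.
    return list(zip(top_token_ids[:top_n_tokens], top_token_logprobs[:top_n_tokens]))


def maybe_decode_top_tokens(
    prefill: bool,
    top_n_tokens: int,
    all_input_ids: List[int],
    top_token_ids: List[int],
    top_token_logprobs: List[float],
) -> Optional[List[Tuple[int, float]]]:
    # The guard added by this patch: skip the prefill step entirely, and on
    # decode steps exclude the token the current forward pass just appended.
    if not prefill and top_n_tokens > 0:
        return decode_top_tokens_stub(
            input_ids=all_input_ids[:-1],
            top_n_tokens=top_n_tokens,
            top_token_ids=top_token_ids,
            top_token_logprobs=top_token_logprobs,
        )
    return None


# Prefill step -> None; decode step -> candidates for the latest position.
assert maybe_decode_top_tokens(True, 2, [1, 2, 3], [7, 9], [-0.1, -1.2]) is None
assert maybe_decode_top_tokens(False, 2, [1, 2, 3], [7, 9], [-0.1, -1.2]) == [(7, -0.1), (9, -1.2)]

Judging from the surrounding context in the diff, the `[:-1]` slice keeps the detokenization input the same as what the pre-change call saw, since the call now runs after the new token has been appended; the flash and seq2seq variants apply the same guard with their own `all_input_ids` / `all_decoder_input_ids` shapes.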