Skip top-n tokens in prefill

Vincent Brouwers 2023-08-01 13:55:38 +00:00
parent d16298b8d4
commit 730d86f1d0
3 changed files with 43 additions and 29 deletions
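
All three model classes get the same treatment: the decode_top_tokens call is moved below the prefill handling, only runs when the step is not a prefill step and the request actually asked for top-n tokens, and drops the token generated in the current step from the ids it decodes. Below is a minimal sketch of that gating, written as a free-standing helper purely for illustration; the real decode_top_tokens is a method on the model classes, its keyword arguments are copied from the calls in the diff, and everything else here is assumed.

# Hypothetical helper sketching the gating this commit introduces.
# `model.decode_top_tokens` and its keyword arguments are assumed to match
# the calls shown in the diff below.
def maybe_decode_top_tokens(
    model,
    prefill: bool,
    top_n_tokens: int,
    all_input_ids: list,
    top_token_ids,
    top_token_logprobs,
    prefix_offset: int,
    read_offset: int,
):
    # Top-n tokens are skipped during the prefill step and when the request
    # did not ask for any (top_n_tokens == 0).
    if prefill or top_n_tokens == 0:
        return None
    # The last id is the token generated in this step; it is dropped so the
    # ids match the position at which the top-n candidates were computed.
    return model.decode_top_tokens(
        input_ids=all_input_ids[:-1],
        top_n_tokens=top_n_tokens,
        top_token_ids=top_token_ids,
        top_token_logprobs=top_token_logprobs,
        prefix_offset=prefix_offset,
        read_offset=read_offset,
    )

In the diff itself the check is inlined separately in CausalLM, FlashCausalLM, and Seq2SeqLM rather than factored out like this.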

@@ -606,15 +606,6 @@ class CausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(-1).tolist(),
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids.view(1, -1), logits[-1:, :]
@@ -661,7 +652,8 @@ class CausalLM(Model):
                 generated_text = None

             # Prefill
-            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+            prefill = stopping_criteria.current_tokens == 1
+            if prefill and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
@@ -680,6 +672,19 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None

+            # Todo: Make optional for prefill
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_input_ids[:-1].view(-1).tolist(),
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,

@@ -982,15 +982,6 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids,
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -1048,6 +1039,19 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None

+            # Todo: Make optional for prefill
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_input_ids[:-1],
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,

@@ -675,15 +675,6 @@ class Seq2SeqLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(-1).tolist(),
-                top_n_tokens=top_n_tokens,
-                top_token_ids=top_token_ids,
-                top_token_logprobs=top_token_logprobs,
-                prefix_offset=prefix_offset,
-                read_offset=read_offset,
-            )
-
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
@@ -731,7 +722,8 @@ class Seq2SeqLM(Model):
                 generated_text = None

             # Prefill
-            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+            prefill = stopping_criteria.current_tokens == 1
+            if prefill and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
@@ -740,6 +732,19 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None

+            # Todo: Make optional for prefill. How to implement in API?
+            if not prefill and top_n_tokens > 0:
+                top_tokens = self.decode_top_tokens(
+                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
+                    top_n_tokens=top_n_tokens,
+                    top_token_ids=top_token_ids,
+                    top_token_logprobs=top_token_logprobs,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,