mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Skip top-n tokens in prefill
This commit is contained in:
parent
d16298b8d4
commit
730d86f1d0
@ -606,15 +606,6 @@ class CausalLM(Model):
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_input_ids.view(-1).tolist(),
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
@ -661,7 +652,8 @@ class CausalLM(Model):
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
prefill = stopping_criteria.current_tokens == 1
|
||||
if prefill and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
@ -680,6 +672,19 @@ class CausalLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
# Todo: Make optional for prefill
|
||||
if not prefill and top_n_tokens > 0:
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_input_ids[:-1].view(-1).tolist(),
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
|
@ -982,15 +982,6 @@ class FlashCausalLM(Model):
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_input_ids,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids.append(next_token_id)
|
||||
|
||||
@ -1048,6 +1039,19 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
# Todo: Make optional for prefill
|
||||
if not prefill and top_n_tokens > 0:
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_input_ids[:-1],
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
|
@ -675,15 +675,6 @@ class Seq2SeqLM(Model):
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_decoder_input_ids.view(-1).tolist(),
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_decoder_input_ids.view(1, -1), logits[-1:, :]
|
||||
@ -731,7 +722,8 @@ class Seq2SeqLM(Model):
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
prefill = stopping_criteria.current_tokens == 1
|
||||
if prefill and request.prefill_logprobs:
|
||||
prefill_tokens = PrefillTokens(
|
||||
[self.tokenizer.bos_token_id],
|
||||
[float("nan")],
|
||||
@ -740,6 +732,19 @@ class Seq2SeqLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
# Todo: Make optional for prefill. How to implement in API?
|
||||
if not prefill and top_n_tokens > 0:
|
||||
top_tokens = self.decode_top_tokens(
|
||||
input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_token_ids=top_token_ids,
|
||||
top_token_logprobs=top_token_logprobs,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user