mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Skip top-n tokens in prefill
This commit is contained in:
parent
d16298b8d4
commit
730d86f1d0
@ -606,15 +606,6 @@ class CausalLM(Model):
|
|||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
top_tokens = self.decode_top_tokens(
|
|
||||||
input_ids=all_input_ids.view(-1).tolist(),
|
|
||||||
top_n_tokens=top_n_tokens,
|
|
||||||
top_token_ids=top_token_ids,
|
|
||||||
top_token_logprobs=top_token_logprobs,
|
|
||||||
prefix_offset=prefix_offset,
|
|
||||||
read_offset=read_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token_id, logprobs = next_token_chooser(
|
next_token_id, logprobs = next_token_chooser(
|
||||||
all_input_ids.view(1, -1), logits[-1:, :]
|
all_input_ids.view(1, -1), logits[-1:, :]
|
||||||
@ -661,7 +652,8 @@ class CausalLM(Model):
|
|||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
prefill = stopping_criteria.current_tokens == 1
|
||||||
|
if prefill and request.prefill_logprobs:
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||||
logits, -1
|
logits, -1
|
||||||
@ -680,6 +672,19 @@ class CausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
|
||||||
|
# Todo: Make optional for prefill
|
||||||
|
if not prefill and top_n_tokens > 0:
|
||||||
|
top_tokens = self.decode_top_tokens(
|
||||||
|
input_ids=all_input_ids[:-1].view(-1).tolist(),
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_token_ids=top_token_ids,
|
||||||
|
top_token_logprobs=top_token_logprobs,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
top_tokens = None
|
||||||
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
|
@ -982,15 +982,6 @@ class FlashCausalLM(Model):
|
|||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
top_tokens = self.decode_top_tokens(
|
|
||||||
input_ids=all_input_ids,
|
|
||||||
top_n_tokens=top_n_tokens,
|
|
||||||
top_token_ids=top_token_ids,
|
|
||||||
top_token_logprobs=top_token_logprobs,
|
|
||||||
prefix_offset=prefix_offset,
|
|
||||||
read_offset=read_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids.append(next_token_id)
|
all_input_ids.append(next_token_id)
|
||||||
|
|
||||||
@ -1048,6 +1039,19 @@ class FlashCausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
|
||||||
|
# Todo: Make optional for prefill
|
||||||
|
if not prefill and top_n_tokens > 0:
|
||||||
|
top_tokens = self.decode_top_tokens(
|
||||||
|
input_ids=all_input_ids[:-1],
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_token_ids=top_token_ids,
|
||||||
|
top_token_logprobs=top_token_logprobs,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
top_tokens = None
|
||||||
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
|
@ -675,15 +675,6 @@ class Seq2SeqLM(Model):
|
|||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
top_tokens = self.decode_top_tokens(
|
|
||||||
input_ids=all_decoder_input_ids.view(-1).tolist(),
|
|
||||||
top_n_tokens=top_n_tokens,
|
|
||||||
top_token_ids=top_token_ids,
|
|
||||||
top_token_logprobs=top_token_logprobs,
|
|
||||||
prefix_offset=prefix_offset,
|
|
||||||
read_offset=read_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token_id, logprobs = next_token_chooser(
|
next_token_id, logprobs = next_token_chooser(
|
||||||
all_decoder_input_ids.view(1, -1), logits[-1:, :]
|
all_decoder_input_ids.view(1, -1), logits[-1:, :]
|
||||||
@ -731,7 +722,8 @@ class Seq2SeqLM(Model):
|
|||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
prefill = stopping_criteria.current_tokens == 1
|
||||||
|
if prefill and request.prefill_logprobs:
|
||||||
prefill_tokens = PrefillTokens(
|
prefill_tokens = PrefillTokens(
|
||||||
[self.tokenizer.bos_token_id],
|
[self.tokenizer.bos_token_id],
|
||||||
[float("nan")],
|
[float("nan")],
|
||||||
@ -740,6 +732,19 @@ class Seq2SeqLM(Model):
|
|||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
|
||||||
|
# Todo: Make optional for prefill. How to implement in API?
|
||||||
|
if not prefill and top_n_tokens > 0:
|
||||||
|
top_tokens = self.decode_top_tokens(
|
||||||
|
input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_token_ids=top_token_ids,
|
||||||
|
top_token_logprobs=top_token_logprobs,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
top_tokens = None
|
||||||
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
|
Loading…
Reference in New Issue
Block a user