Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Allocate top_n_token tensor in Batch
This commit is contained in:
parent 65c0d9c19d
commit 1b5fdf7000
@@ -44,6 +44,7 @@ class CausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -125,6 +126,7 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -143,6 +145,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -230,6 +233,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
@@ -243,6 +247,7 @@ class CausalLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -278,6 +283,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None
 
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
@@ -320,6 +326,12 @@ class CausalLMBatch(Batch):
                     (total_batch_size, max_input_length + padding_right_offset),
                 )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
@@ -449,6 +461,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -561,7 +574,7 @@ class CausalLM(Model):
             stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )
 
         # Zipped iterator
@@ -594,7 +607,7 @@ class CausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(1, -1).tolist(),
+                input_ids=all_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
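The hunks above all apply the same pattern to CausalLMBatch: keep top_n_tokens as a plain Python list for per-request bookkeeping, and carry a second copy as a tensor allocated on the model device once, when the batch is built, filtered, or concatenated, rather than on every decode step. A minimal sketch of that pattern, using a hypothetical TopNState container (from_requests and filter are illustrative names, not part of text-generation-inference):

# Minimal sketch of the list-plus-tensor pattern from the hunks above.
# `TopNState`, `from_requests` and `filter` are illustrative names only.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class TopNState:
    top_n_tokens: List[int]            # per-request values, kept on the host
    top_n_tokens_tensor: torch.Tensor  # same values, pre-allocated on the model device

    @classmethod
    def from_requests(cls, top_n_tokens: List[int], device: torch.device) -> "TopNState":
        # Allocate the device tensor once, at batch construction time
        return cls(
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=torch.tensor(top_n_tokens, device=device, dtype=torch.int64),
        )

    def filter(self, keep_indices: List[int]) -> "TopNState":
        # Mirror Batch.filter in the diff: index the existing tensor instead of rebuilding it
        return TopNState(
            top_n_tokens=[self.top_n_tokens[i] for i in keep_indices],
            top_n_tokens_tensor=self.top_n_tokens_tensor[keep_indices],
        )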
@@ -168,6 +168,7 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
     blocks: int
@@ -357,6 +358,7 @@ class FlashCausalLMBatch(Batch):
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         return cls(
             batch_id=pb.id,
@@ -384,6 +386,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -496,6 +499,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -528,6 +532,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -576,6 +581,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )
 
         start_slots = []
         block_tables = []
@@ -613,6 +621,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
             all_input_ids_tensor[
@@ -680,6 +689,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -850,7 +860,7 @@ class FlashCausalLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
         if prefill:
@@ -108,14 +108,15 @@ class Model(ABC):
             new_sequences, skip_special_tokens=False
         )
 
+        prefix_len = len(prefix_text)
         results = []
         for new_text in new_texts:
-            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            if len(new_text) > prefix_len and not new_text.endswith("�"):
                 # utf-8 char at the end means it's a potential unfinished byte sequence
                 # from byte fallback tokenization.
                 # If it's in the middle, it's probably a real invalid id generated
                 # by the model
-                new_text = new_text[len(prefix_text) :]
+                new_text = new_text[prefix_len:]
                 results.append((new_text, read_offset, len(input_ids) + 1))
             else:
                 results.append(("", prefix_offset, read_offset))
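The hunk above sits in the incremental detokenization path: the tokenizer decodes a prefix of the sequence and the full sequence, and only the text extending past the prefix is emitted, unless it ends in the U+FFFD replacement character, which usually marks an unfinished multi-byte sequence from byte-fallback tokenization. A standalone sketch of that idea, with a hypothetical helper name (incremental_decode) and simplified offset handling that does not claim to match the repository's exact code:

# Hedged sketch of incremental detokenization; `incremental_decode` is an
# illustrative helper, not the repository's API.
from typing import List, Tuple

from transformers import PreTrainedTokenizerBase


def incremental_decode(
    tokenizer: PreTrainedTokenizerBase,
    input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode a short prefix, then the prefix plus the newly generated ids
    prefix_text = tokenizer.decode(input_ids[prefix_offset:read_offset], skip_special_tokens=False)
    new_text = tokenizer.decode(input_ids[prefix_offset:], skip_special_tokens=False)

    prefix_len = len(prefix_text)
    if len(new_text) > prefix_len and not new_text.endswith("�"):
        # Emit only the text that extends past the prefix and advance the offsets
        return new_text[prefix_len:], read_offset, len(input_ids)
    # A trailing U+FFFD usually means an unfinished byte sequence: emit nothing yet
    return "", prefix_offset, read_offset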
@@ -50,6 +50,7 @@ class Seq2SeqLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -129,6 +130,7 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -150,6 +152,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -245,6 +248,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -261,6 +265,7 @@ class Seq2SeqLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -304,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []
 
         # Used for slicing correctly inside the tensors
@@ -393,6 +399,12 @@ class Seq2SeqLMBatch(Batch):
                     ),
                 )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -498,6 +510,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -624,7 +637,7 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )
 
         # Finished requests
@@ -663,7 +676,7 @@ class Seq2SeqLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                input_ids=all_decoder_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
@@ -337,20 +337,16 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: list[int], logprobs: torch.Tensor
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.
 
     When multiple tokens have equal probabilities and they don't all fit, the
     remaining tokens are also returned.
     """
-    # Do this as early as possible to mitigate copy latency
-    top_n_tensor = torch.tensor(top_n_tokens).to(
-        device=logprobs.device, non_blocking=True
-    )
-
+    max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
-    if max(top_n_tokens) == 0:
+    if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
@@ -358,11 +354,9 @@ def batch_top_tokens(
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(
-        logprobs, k=max(top_n_tokens), dim=1, sorted=True
-    ).values  # .cpu()
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
     nth_highest = torch.gather(
-        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
     nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
 
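Together, these last two hunks change batch_top_tokens to take the batch's pre-allocated top_n_tokens_tensor instead of building a device tensor from the Python list on every call. A simplified, self-contained sketch of the resulting signature and a call site; the final per-row selection is condensed here and is not the repository's exact implementation:

# Simplified sketch of the post-commit batch_top_tokens contract; the body is
# condensed and only illustrates the idea, not the repository's exact logic.
from typing import List, Tuple

import torch


def batch_top_tokens_sketch(
    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # Value of the n-th highest logprob per row, using the pre-allocated index tensor
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    top_ids, top_logprobs = [], []
    for row, n, cutoff in zip(logprobs, top_n_tokens, nth_highest.squeeze(1).tolist()):
        if n == 0:
            top_ids.append([])
            top_logprobs.append([])
            continue
        # Keep every token whose logprob ties or beats the n-th highest value,
        # so equal-probability tokens that don't all fit are still returned
        ids = (row >= cutoff).nonzero(as_tuple=True)[0]
        top_ids.append(ids.tolist())
        top_logprobs.append(row[ids].tolist())
    return top_ids, top_logprobs


# Call site, mirroring the generate_token hunks: the tensor comes from the batch,
# so no host-to-device copy happens inside the decode loop.
logprobs = torch.log_softmax(torch.randn(2, 16), dim=-1)
top_n = [3, 0]
top_n_tensor = torch.tensor(top_n, dtype=torch.int64)
ids, lps = batch_top_tokens_sketch(top_n, top_n_tensor, logprobs)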