Fixing top_n_tokens.

Nicolas Patry 2024-01-26 16:38:03 +00:00
parent d9758851be
commit 6e629add98
3 changed files with 63 additions and 36 deletions


@@ -842,6 +842,8 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
+        speculate = get_speculate()
         (
             next_input_ids,
             next_token_logprobs,
@@ -851,16 +853,16 @@ class FlashCausalLM(Model):
         ) = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen],
             next_token_logits,
-            get_speculate(),
+            speculate,
             batch.speculative_ids,
             speculative_logits,
         )
+        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
         )
         speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1062,6 +1064,8 @@ class FlashCausalLM(Model):
             prefill_tokens = None
             if top_n_tokens > 0:
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -1076,6 +1080,8 @@ class FlashCausalLM(Model):
                         toptoken_texts,
                         special_toptokens,
                     )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
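
With speculative decoding a request can emit several tokens per step, so the loop added above builds one group of top tokens per accepted token instead of a single group per request. A minimal standalone sketch of that shape change, using a stand-in Tokens dataclass and a stand-in decode callable rather than the server's tokenizer and protobuf-backed Tokens class:

from dataclasses import dataclass
from typing import Callable, List, Set

@dataclass
class Tokens:
    # Stand-in for the server's Tokens class, fields kept to the essentials.
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

def group_top_tokens(
    top_token_ids: List[List[int]],             # one inner list per accepted token
    top_token_logprobs: List[List[float]],
    decode: Callable[[List[int]], List[str]],   # stand-in for tokenizer.batch_decode
    special_ids: Set[int],
) -> List[Tokens]:
    all_top_tokens = []
    for ids, logprobs in zip(top_token_ids, top_token_logprobs):
        texts = decode(ids)
        special = [token_id in special_ids for token_id in ids]
        all_top_tokens.append(Tokens(ids, logprobs, texts, special))
    return all_top_tokens   # one Tokens group per accepted (speculative) token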


@@ -95,5 +95,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
         )
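
The serialization follows the same shape change: Generation.top_tokens is now a list of groups, so each group is converted individually. A small sketch with a stand-in TokenGroup class whose to_pb returns a plain dict (the real class builds generate_pb2 messages):

from typing import List, Optional

class TokenGroup:
    """Stand-in for the server's Tokens class; to_pb here returns a plain dict."""
    def __init__(self, ids: List[int]):
        self.ids = ids

    def to_pb(self) -> dict:
        return {"ids": self.ids}

def serialize_top_tokens(top_tokens: Optional[List[TokenGroup]]) -> Optional[List[dict]]:
    # Mirrors the new expression: convert each group, or keep None when unset.
    return [group.to_pb() for group in top_tokens] if top_tokens is not None else None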


@@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
-        scores = scores.view(B * S, -1)
+        allscores = scores.view(B * S, -1)
+        alllogprobs = torch.log_softmax(allscores, -1)
         if speculated_ids is not None:
             accepted_ids = []
@@ -305,15 +306,14 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
-            scores = scores[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
         if speculate > 0:
             if speculative_scores is not None:
@@ -327,7 +327,7 @@ class HeterogeneousNextTokenChooser:
             else:
                 speculative_ids = None
-        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
+        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
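
A toy illustration, with assumed shapes and names, of why the log-softmax now runs over the full (B * S, vocab) score matrix before any filtering: keeping one row of log-probabilities per speculative position is what lets the chooser return a tensor that batch_top_tokens can later slice per accepted token. This is not the server's code, just the shape arithmetic:

import torch

B, S, V = 2, 3, 10                         # batch, speculated_length + 1, vocab (toy sizes)
scores = torch.randn(B * S, V)             # one row per (sequence, speculative position)

allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)     # full (B*S, V) log-probabilities

next_ids = alllogprobs.argmax(-1)                  # toy stand-in for the chosen token ids
next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)

assert alllogprobs.shape == (B * S, V)
assert next_logprobs.shape == (B * S,)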
@@ -436,7 +436,7 @@ class HeterogeneousSampling:
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.
@@ -448,12 +448,16 @@ def batch_top_tokens(
     if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+    n = speculative_length + 1
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
     nth_highest = torch.gather(
         sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
@@ -471,13 +475,30 @@ def batch_top_tokens(
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()
-    return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
-    )
+    batch_top_token_ids = []
+    batch_top_token_logprobs = []
+    accepted_ids = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids):
+        _top_indices = top_indices[n * i: n * (i + 1)]
+        _top_values = top_values[n * i: n * (i + 1)]
+        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
+        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+        _top_indices = _top_indices[:n_accepted_ids]
+        _top_values = _top_values[:n_accepted_ids]
+        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+        _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+        row_top_token_ids = []
+        row_top_token_logprobs = []
+        for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+            indices = idxs[:n] if req_n > 0 else []
+            values = vals[:n] if req_n > 0 else []
+            row_top_token_ids.append(indices)
+            row_top_token_logprobs.append(values)
+        batch_top_token_ids.append(row_top_token_ids)
+        batch_top_token_logprobs.append(row_top_token_logprobs)
+    return batch_top_token_ids, batch_top_token_logprobs
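
A worked toy example of the regrouping performed above, assuming two requests, a speculative length of 2, and a vocabulary of 8 tokens; the names mirror the diff, but this is an illustration, not the server function. Top-k runs once over every (request, position) row, then the rows are sliced back per request and truncated to the tokens that were actually accepted:

import torch

top_n_tokens = [2, 3]                     # requested top-n per request
accepted_ids = torch.tensor([1, 3])       # tokens actually accepted per request
speculative_length = 2
n = speculative_length + 1                # positions (rows) per request

# One row of log-probabilities per (request, speculative position).
logprobs = torch.log_softmax(torch.randn(len(top_n_tokens) * n, 8), dim=-1)

max_top_n = max(top_n_tokens)
top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True)
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()

batch_top_token_ids = []
batch_top_token_logprobs = []
for i, n_accepted in enumerate(accepted_ids.tolist()):
    # Keep only the rows belonging to request i that were actually accepted.
    rows_ids = top_indices[n * i: n * i + n_accepted]
    rows_vals = top_values[n * i: n * i + n_accepted]
    batch_top_token_ids.append([row[: top_n_tokens[i]] for row in rows_ids])
    batch_top_token_logprobs.append([row[: top_n_tokens[i]] for row in rows_vals])

print([len(group) for group in batch_top_token_ids])   # [1, 3]: one list per accepted token

The nth-highest cutoff that the real function applies via top_n_tokens_tensor.repeat_interleave(n) is omitted here to keep the sketch short.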