Fixing top_n_tokens.

This commit is contained in:
Nicolas Patry 2024-01-26 16:38:03 +00:00
parent d9758851be
commit 6e629add98
3 changed files with 63 additions and 36 deletions

View File

@ -842,6 +842,8 @@ class FlashCausalLM(Model):
else: else:
next_token_logits = out next_token_logits = out
speculate = get_speculate()
( (
next_input_ids, next_input_ids,
next_token_logprobs, next_token_logprobs,
@ -851,16 +853,16 @@ class FlashCausalLM(Model):
) = batch.next_token_chooser( ) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits, next_token_logits,
get_speculate(), speculate,
batch.speculative_ids, batch.speculative_ids,
speculative_logits, speculative_logits,
) )
speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
) )
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill: if prefill:
if len(batch) > 1 and prefill_logprobs: if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -1062,6 +1064,8 @@ class FlashCausalLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
top_token_ids, top_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
@ -1076,6 +1080,8 @@ class FlashCausalLM(Model):
toptoken_texts, toptoken_texts,
special_toptokens, special_toptokens,
) )
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -95,5 +95,5 @@ class Generation:
generated_text=self.generated_text.to_pb() generated_text=self.generated_text.to_pb()
if self.generated_text is not None if self.generated_text is not None
else None, else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None, top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
) )

View File

@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S) next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1) allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)
if speculated_ids is not None: if speculated_ids is not None:
accepted_ids = [] accepted_ids = []
@ -305,15 +306,14 @@ class HeterogeneousNextTokenChooser:
accepted_ids, device=input_ids.device, dtype=input_ids.dtype accepted_ids, device=input_ids.device, dtype=input_ids.dtype
) )
next_ids = next_ids[indices] next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None: if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1] speculative_scores = speculative_scores[indices + accepted_ids - 1]
else: else:
accepted_ids = torch.ones_like(next_ids) accepted_ids = torch.ones_like(next_ids)
logprobs = torch.log_softmax(scores, -1) next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate > 0:
if speculative_scores is not None: if speculative_scores is not None:
@ -327,7 +327,7 @@ class HeterogeneousNextTokenChooser:
else: else:
speculative_ids = None speculative_ids = None
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def filter(self, indices): def filter(self, indices):
if self.watermark_processor is not None: if self.watermark_processor is not None:
@ -436,7 +436,7 @@ class HeterogeneousSampling:
def batch_top_tokens( def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
) -> Tuple[List[List[int]], List[List[float]]]: ) -> Tuple[List[List[int]], List[List[float]]]:
"""Find the top n most likely tokens for a batch of generations. """Find the top n most likely tokens for a batch of generations.
@ -448,12 +448,16 @@ def batch_top_tokens(
if max_top_n == 0: if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens) return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
n = speculative_length + 1
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
# Ensure top_n doesn't exceed vocab size # Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens] top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset # Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
nth_highest = torch.gather( nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1) sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
) )
@ -471,13 +475,30 @@ def batch_top_tokens(
top_indices = top_k.indices.tolist() top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist() top_values = top_k.values.tolist()
return ( batch_top_token_ids = []
[ batch_top_token_logprobs = []
idxs[:n] if req_n > 0 else [] accepted_ids = accepted_ids.tolist()
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens) for i, n_accepted_ids in enumerate(accepted_ids):
], _top_indices = top_indices[n * i: n * (i + 1)]
[ _top_values = top_values[n * i: n * (i + 1)]
vals[:n] if req_n > 0 else [] _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens) _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
], _top_indices = _top_indices[:n_accepted_ids]
) _top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
row_top_token_ids.append(indices)
row_top_token_logprobs.append(values)
batch_top_token_ids.append(row_top_token_ids)
batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs