Mirror of https://github.com/huggingface/text-generation-inference.git

Fixing top_n_tokens.

commit 6e629add98
parent d9758851be
@@ -842,6 +842,8 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

+        speculate = get_speculate()
+
         (
             next_input_ids,
             next_token_logprobs,
@@ -851,16 +853,16 @@ class FlashCausalLM(Model):
         ) = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen],
             next_token_logits,
-            get_speculate(),
+            speculate,
             batch.speculative_ids,
             speculative_logits,
         )

+        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
         )

-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
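A note on the hunk above: with speculative decoding, next_token_chooser returns one row of logprobs per (request, speculated position), so the top-n helper now also receives accepted_ids and the speculated length in order to map rows back to requests. A minimal sketch of the assumed row layout (sizes are hypothetical, not TGI code):

import torch

B = 2                  # requests in the batch (hypothetical)
speculated_length = 3  # speculative tokens proposed per request (hypothetical)
vocab = 5

n = speculated_length + 1  # positions verified per request
logprobs = torch.log_softmax(torch.randn(B * n, vocab), -1)

# Row r belongs to request r // n, speculated position r % n.
for r in range(B * n):
    request, position = divmod(r, n)
    assert logprobs[r].shape == (vocab,)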
@@ -1062,20 +1064,24 @@ class FlashCausalLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

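With several speculated tokens verified per step, top_token_ids and top_token_logprobs for one request are now lists with one entry per accepted token, and the loop above builds one Tokens group per entry. A small sketch of the assumed shapes, using stand-in dicts instead of Tokens:

top_token_ids = [[5, 9], [7]]                # 2 accepted tokens, top-n ids for each (hypothetical)
top_token_logprobs = [[-0.1, -2.3], [-0.4]]

all_top_tokens = []
for ids, logprobs in zip(top_token_ids, top_token_logprobs):
    # Stand-in for Tokens(ids, logprobs, texts, is_special)
    all_top_tokens.append({"ids": ids, "logprobs": logprobs})

assert len(all_top_tokens) == 2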
@@ -95,5 +95,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
         )
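Since a Generation can now carry several Tokens groups (one per accepted speculative token), the optional field is serialized element by element. A minimal sketch of that pattern, with a hypothetical stand-in for Tokens.to_pb():

class FakeTokens:
    """Stand-in for Tokens; to_pb() mimics protobuf conversion."""
    def to_pb(self):
        return {"pb": True}

def serialize_top_tokens(top_tokens):
    return [t.to_pb() for t in top_tokens] if top_tokens is not None else None

assert serialize_top_tokens(None) is None
assert serialize_top_tokens([FakeTokens(), FakeTokens()]) == [{"pb": True}, {"pb": True}]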
@@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
-        scores = scores.view(B * S, -1)
+        allscores = scores.view(B * S, -1)
+        alllogprobs = torch.log_softmax(allscores, -1)

         if speculated_ids is not None:
             accepted_ids = []
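The chooser now keeps log-probabilities for every flattened (request, position) row (alllogprobs) instead of a later-filtered scores tensor, so top-n extraction downstream can still see all speculated positions. A shape-only sketch under assumed sizes:

import torch

B, S, vocab = 2, 3, 7            # requests, positions per request, vocab (hypothetical)
scores = torch.randn(B, S, vocab)

allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)

assert alllogprobs.shape == (B * S, vocab)
# Each row sums to 1 in probability space.
assert torch.allclose(alllogprobs.exp().sum(-1), torch.ones(B * S))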
@@ -305,15 +306,14 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
-            scores = scores[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)

-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)

         if speculate > 0:
             if speculative_scores is not None:
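One detail worth spelling out from the context above is how the last accepted position of each request is selected for speculative_scores: torch.arange(B) * S gives the first flattened row of each request, and adding accepted_ids - 1 moves to its last accepted row. A small sketch with assumed sizes:

import torch

B, S = 2, 4                          # requests, positions per request (hypothetical)
accepted_ids = torch.tensor([3, 1])  # accepted tokens per request

indices = torch.arange(B) * S        # first flattened row of each request: [0, 4]
last_accepted = indices + accepted_ids - 1

assert last_accepted.tolist() == [2, 4]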
@@ -327,7 +327,7 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None

-        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
+        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -436,7 +436,7 @@ class HeterogeneousSampling:


 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.

@@ -448,12 +448,16 @@ def batch_top_tokens(
     if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

+    n = speculative_length + 1
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
+
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]

     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
+
     nth_highest = torch.gather(
         sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
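Because logprobs now has speculative_length + 1 rows per request, both the tensor and the Python list of per-request top-n budgets are expanded so every row keeps its own budget. A sketch with hypothetical values:

import torch

speculative_length = 2
n = speculative_length + 1
top_n_tokens = [2, 0, 5]                      # requested top-n per request (hypothetical)
top_n_tokens_tensor = torch.tensor(top_n_tokens)

expanded_tensor = top_n_tokens_tensor.repeat_interleave(n)
expanded_list = [tok for tok in top_n_tokens for _ in range(n)]

assert expanded_tensor.tolist() == [2, 2, 2, 0, 0, 0, 5, 5, 5]
assert expanded_list == expanded_tensor.tolist()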
@@ -471,13 +475,30 @@ def batch_top_tokens(
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()

-    return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
-    )
+    batch_top_token_ids = []
+    batch_top_token_logprobs = []
+    accepted_ids = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids):
+        _top_indices = top_indices[n * i: n * (i + 1)]
+        _top_values = top_values[n * i: n * (i + 1)]
+        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
+        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+        _top_indices = _top_indices[:n_accepted_ids]
+        _top_values = _top_values[:n_accepted_ids]
+        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+        _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+        row_top_token_ids = []
+        row_top_token_logprobs = []
+
+        for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+            indices = idxs[:n] if req_n > 0 else []
+            values = vals[:n] if req_n > 0 else []
+
+            row_top_token_ids.append(indices)
+            row_top_token_logprobs.append(values)
+
+        batch_top_token_ids.append(row_top_token_ids)
+        batch_top_token_logprobs.append(row_top_token_logprobs)
+
+    return batch_top_token_ids, batch_top_token_logprobs
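The rewritten tail of batch_top_tokens regroups the flattened per-position results into one list per request and keeps only the rows for accepted tokens. A standalone sketch of that slicing logic with hypothetical data:

n = 3                                        # speculative_length + 1
accepted_ids = [2, 1]                        # accepted tokens per request (hypothetical)
top_indices = [[1], [4], [7], [2], [9], [3]] # one top-n list per flattened row

batch_top_token_ids = []
for i, n_accepted_ids in enumerate(accepted_ids):
    rows = top_indices[n * i: n * (i + 1)][:n_accepted_ids]
    batch_top_token_ids.append(rows)

assert batch_top_token_ids == [[[1], [4]], [[2]]]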