faster cumsum

OlivierDehaene 2023-05-25 17:59:13 +02:00
parent caa9608347
commit d3cb0d3b83
5 changed files with 143 additions and 57 deletions
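The commit title refers to replacing a single batched cumsum over the sorted probabilities with a per-row loop inside the captured CUDA graph (see the logits-warper hunks below). A minimal microbenchmark sketch for comparing the two; it is not part of the commit, the shapes are illustrative, and which variant wins will depend on GPU, batch size, and PyTorch version:

import time

import torch

def time_it(fn, iters=100):
    # Synchronize around the timed region so the GPU work is fully counted.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

probs = torch.rand(4, 50257, device="cuda").softmax(dim=-1)

def batched():
    probs.cumsum(dim=-1)

def per_row():
    for i in range(probs.shape[0]):
        probs[i].cumsum(dim=-1)

print("batched cumsum:", time_it(batched))
print("per-row cumsum:", time_it(per_row))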

File 1 of 5

@@ -286,7 +286,9 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens

File 2 of 5

@@ -289,7 +289,9 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens

File 3 of 5

@@ -325,7 +325,9 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

File 4 of 5

@@ -129,6 +129,18 @@ class HeterogeneousTemperatureLogitsWarper:
         return self
 
 
+class CUDAGraphWrapper:
+    def __init__(self):
+        self.cuda_graph = None
+        self.static_tensors = {}
+
+
+@lru_cache(512)
+def get_cuda_graph_wrapper(warper_name, batch_size):
+    """warper_name and batch_size are only used as keys"""
+    return CUDAGraphWrapper()
+
+
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
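The wrapper above holds a torch.cuda.CUDAGraph plus the static tensors it replays against. For readers unfamiliar with the API, a minimal self-contained sketch of the capture/replay pattern the warpers below rely on (shapes are illustrative; a CUDA device is required, and production code typically warms the ops up on a side stream before capture):

import torch

# Static buffers: a captured graph always reads and writes the same memory.
static_in = torch.randn(4, 50257, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Kernels launched here are recorded into the graph, not executed eagerly.
    static_out = static_in.softmax(dim=-1)

# Per call: copy fresh data into the static input, replay the recorded
# kernels, then read the result out of the static output buffer.
static_in.copy_(torch.randn(4, 50257, device="cuda"))
graph.replay()
result = static_out.clone()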
@@ -158,13 +170,30 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors[
+                "top_p_opposite"
+            ] = self.top_p_opposite
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+                sorted_logits, sorted_indices = torch.sort(
+                    local_scores, descending=False
+                )
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
 
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+                # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+                sorted_indices_to_remove = probs <= self.top_p_opposite
+                if self.min_tokens_to_keep > 1:
+                    # Keep at least min_tokens_to_keep
+                    sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
@@ -173,11 +202,24 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    1, sorted_indices, sorted_indices_to_remove
+                )
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
+            self.top_p_opposite
+        )
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.top_p_opposite = self.top_p_opposite[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "top_p_warper", len(self.top_p_opposite)
+        )
         return self
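Note how filter() re-fetches a wrapper: shrinking the batch changes tensor shapes, and a captured graph only replays against the shapes it was recorded with, so each batch size must map to its own wrapper. A small sketch of the cache behavior, using the get_cuda_graph_wrapper helper added above:

w4 = get_cuda_graph_wrapper("top_p_warper", 4)
assert w4 is get_cuda_graph_wrapper("top_p_warper", 4)      # same key: cached instance
assert w4 is not get_cuda_graph_wrapper("top_p_warper", 2)  # new batch size: fresh graph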
@@ -279,21 +321,36 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # calculate entropy
-        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
-        p = torch.exp(normalized)
-        ent = -(normalized * p).nansum(-1, keepdim=True)
-
-        # shift and sort
-        shifted_scores = torch.abs((-normalized) - ent)
-        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-        sorted_logits = scores.gather(-1, sorted_indices)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-
-        # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+                # calculate entropy
+                normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
+                p = torch.exp(normalized)
+                ent = -(normalized * p).nansum(-1, keepdim=True)
+
+                # shift and sort
+                shifted_scores = torch.abs((-normalized) - ent)
+                sorted_scores, sorted_indices = torch.sort(
+                    shifted_scores, descending=False
+                )
+                sorted_logits = local_scores.gather(-1, sorted_indices)
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
+
+                # Remove tokens with cumulative mass above the threshold
+                last_ind = (probs < self.mass).sum(dim=1)
+                last_ind[last_ind < 0] = 0
+                sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+                    1, last_ind.view(-1, 1)
@@ -305,11 +362,22 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        scores = scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    1, sorted_indices, sorted_indices_to_remove
+                )
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.mass = self.mass[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "typical_p_warper", len(self.mass)
+        )
         return self
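For reference, the computation captured above is plain typical-p filtering: keep the tokens whose surprisal is closest to the distribution's entropy until their cumulative probability reaches the mass threshold. A small eager-mode sketch with made-up numbers (not from the commit) showing what the recorded kernels compute:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
mass = torch.tensor([[0.9]])

normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)  # per-row entropy
# distance of each token's surprisal from the entropy; "typical" tokens are close
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
last_ind = (cumulative_probs < mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# keeps the four most "typical" tokens and masks the outlier at -1.0
print(scores.masked_fill(indices_to_remove, -float("inf")))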

File 5 of 5

@@ -4,7 +4,6 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
-    LogitsProcessorList,
 )
 
 from typing import List, Tuple, Optional
@@ -166,10 +165,9 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
     ):
-        warpers = LogitsProcessorList()
+        warpers = []
 
-        if any(watermark):
-            warpers.append(
-                HeterogeneousProcessorWrapper(
-                    {
-                        i: WatermarkLogitsProcessor(device=device)
+        self.watermark_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: WatermarkLogitsProcessor(device=device)
@@ -177,13 +175,16 @@ class HeterogeneousNextTokenChooser:
-                        if do_watermark
-                    }
-                )
-            )
+                    if do_watermark
+                }
+            )
+            if any(watermark)
+            else None
+        )
 
-        if any([x != 1.0 for x in repetition_penalty]):
-            warpers.append(
-                HeterogeneousRepetitionPenaltyLogitsProcessor(
-                    repetition_penalty, dtype, device
-                )
-            )
+        self.repetition_processor = (
+            HeterogeneousRepetitionPenaltyLogitsProcessor(
+                repetition_penalty, dtype, device
+            )
+            if any([x != 1.0 for x in repetition_penalty])
+            else None
+        )
 
         if any([x != 1.0 for x in temperature]):
@@ -217,11 +218,22 @@ class HeterogeneousNextTokenChooser:
         self.seeds = seeds
         self.do_sample = do_sample
 
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        last_token_scores = self.warpers(input_ids, scores)
-        next_ids = self.choice(last_token_scores)
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
+
+        for warper in self.warpers:
+            scores = warper(input_ids, scores)
+
+        next_ids = self.choice(scores)
         next_logprobs = torch.gather(
-            torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
+            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
         ).view(-1)
 
         return next_ids, next_logprobs
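With LogitsProcessorList gone, the chooser calls each processor and warper explicitly, keeping the graph-backed warpers on a plain attribute-access path. A sketch of the same per-step contract with hypothetical stand-in warpers (everything below is illustrative, not TGI API):

import torch

# Stand-in warpers with the (input_ids, scores) -> scores contract the loop assumes.
def temperature_warper(input_ids, scores):
    return scores / 0.7

def top_k_warper(input_ids, scores):
    kth = torch.topk(scores, k=50, dim=-1).values[..., -1, None]
    return scores.masked_fill(scores < kth, -float("inf"))

warpers = [temperature_warper, top_k_warper]

input_ids = torch.zeros(4, 1, dtype=torch.long)
scores = torch.randn(4, 50257)

for warper in warpers:  # explicit loop, as in the new __call__
    scores = warper(input_ids, scores)

next_ids = torch.argmax(scores, dim=-1)  # greedy stand-in for self.choice
next_logprobs = torch.gather(
    torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
).view(-1)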