Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)

commit d3cb0d3b83 "faster cumsum"
parent caa9608347
@@ -286,7 +286,9 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -289,7 +289,9 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -325,7 +325,9 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -129,6 +129,18 @@ class HeterogeneousTemperatureLogitsWarper:
         return self
 
 
+class CUDAGraphWrapper:
+    def __init__(self):
+        self.cuda_graph = None
+        self.static_tensors = {}
+
+
+@lru_cache(512)
+def get_cuda_graph_wrapper(warper_name, batch_size):
+    """warper_name and batch_size are only used as keys"""
+    return CUDAGraphWrapper()
+
+
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
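Note: a minimal standalone sketch of the caching behaviour the hunk above introduces. The class and function bodies are copied from the hunk; the usage assertions at the bottom are illustrative assumptions, not repository code.

from functools import lru_cache


class CUDAGraphWrapper:
    # Holds one captured CUDA graph plus the static tensors it reads and writes.
    def __init__(self):
        self.cuda_graph = None
        self.static_tensors = {}


@lru_cache(512)
def get_cuda_graph_wrapper(warper_name, batch_size):
    """warper_name and batch_size are only used as cache keys."""
    return CUDAGraphWrapper()


# Same (name, batch size) key -> same wrapper, so a graph captured for that
# shape is reused instead of being re-captured by every warper instance.
a = get_cuda_graph_wrapper("top_p_warper", 8)
b = get_cuda_graph_wrapper("top_p_warper", 8)
assert a is b
assert get_cuda_graph_wrapper("top_p_warper", 4) is not a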
@@ -158,13 +170,30 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors[
+                "top_p_opposite"
+            ] = self.top_p_opposite
+
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+
+                sorted_logits, sorted_indices = torch.sort(
+                    local_scores, descending=False
+                )
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
 
                 # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
+                sorted_indices_to_remove = probs <= self.top_p_opposite
                 if self.min_tokens_to_keep > 1:
                     # Keep at least min_tokens_to_keep
                     sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
@@ -173,11 +202,24 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
                 indices_to_remove = sorted_indices_to_remove.scatter(
                     1, sorted_indices, sorted_indices_to_remove
                 )
-        scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
+            self.top_p_opposite
+        )
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.top_p_opposite = self.top_p_opposite[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "top_p_warper", len(self.top_p_opposite)
+        )
         return self
 
 
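Note: the commit title refers to the per-row cumsum loop marked "This is way faster for some reason" above. A rough, self-contained timing harness for comparing the two variants could look like the sketch below; the batch size, vocabulary size, and iteration count are made-up values, and real numbers depend on the GPU and on whether the loop runs under CUDA graph capture.

import torch


def one_shot_cumsum(probs: torch.Tensor) -> torch.Tensor:
    # Single call over the whole batch.
    return probs.cumsum(dim=-1)


def per_row_cumsum(probs: torch.Tensor) -> torch.Tensor:
    # One small cumsum per row, as in the loop added by this commit.
    out = probs.clone()
    for i in range(out.shape[0]):
        out[i] = out[i].cumsum(dim=-1)
    return out


if torch.cuda.is_available():
    # Hypothetical shape: batch of 32, GPT-2-sized vocabulary.
    probs = torch.rand(32, 50257, device="cuda").softmax(dim=-1)
    torch.testing.assert_close(one_shot_cumsum(probs), per_row_cumsum(probs))

    for fn in (one_shot_cumsum, per_row_cumsum):
        fn(probs)  # warm-up
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(20):
            fn(probs)
        end.record()
        torch.cuda.synchronize()
        print(f"{fn.__name__}: {start.elapsed_time(end) / 20:.3f} ms/iter")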
@@ -279,21 +321,36 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
+
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+
                 # calculate entropy
-        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+                normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
                 p = torch.exp(normalized)
                 ent = -(normalized * p).nansum(-1, keepdim=True)
 
                 # shift and sort
                 shifted_scores = torch.abs((-normalized) - ent)
-        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-        sorted_logits = scores.gather(-1, sorted_indices)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+                sorted_scores, sorted_indices = torch.sort(
+                    shifted_scores, descending=False
+                )
+                sorted_logits = local_scores.gather(-1, sorted_indices)
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
 
                 # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+                last_ind = (probs < self.mass).sum(dim=1)
                 last_ind[last_ind < 0] = 0
                 sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
                     1, last_ind.view(-1, 1)
@@ -305,11 +362,22 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
                     1, sorted_indices, sorted_indices_to_remove
                 )
 
-        scores = scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.mass = self.mass[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "typical_p_warper", len(self.mass)
+        )
         return self
 
 
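Note: both warpers now follow the same capture-once / replay pattern: the first call records the warping math into a torch.cuda.CUDAGraph over static tensors, and every later call copies fresh inputs into those tensors, replays the graph, and reads a static output tensor. A minimal sketch of that pattern, with a toy softmax-plus-cumsum body standing in for the real warper logic and with the warm-up stream that PyTorch recommends before capture (not present in the hunks above):

import torch

if torch.cuda.is_available():
    # Static input buffer: the captured graph always reads from this tensor.
    static_scores = torch.randn(4, 32, device="cuda")

    # Warm up on a side stream before capturing, as the PyTorch docs suggest.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        static_scores.softmax(dim=-1).cumsum(dim=-1)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture once: the kernels below are recorded into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_warped = static_scores.softmax(dim=-1).cumsum(dim=-1)

    # Every later call: copy new inputs in, replay, read the static output.
    new_scores = torch.randn(4, 32, device="cuda")
    static_scores.copy_(new_scores)
    graph.replay()
    warped = static_warped.clone()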
@@ -4,7 +4,6 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
-    LogitsProcessorList,
 )
 from typing import List, Tuple, Optional
 
@@ -166,10 +165,9 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
     ):
-        warpers = LogitsProcessorList()
+        warpers = []
 
-        if any(watermark):
-            warpers.append(
+        self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
                     i: WatermarkLogitsProcessor(device=device)
@@ -177,13 +175,16 @@ class HeterogeneousNextTokenChooser:
                     if do_watermark
                 }
             )
+            if any(watermark)
+            else None
         )
 
-        if any([x != 1.0 for x in repetition_penalty]):
-            warpers.append(
+        self.repetition_processor = (
             HeterogeneousRepetitionPenaltyLogitsProcessor(
                 repetition_penalty, dtype, device
             )
+            if any([x != 1.0 for x in repetition_penalty])
+            else None
         )
 
         if any([x != 1.0 for x in temperature]):
@@ -217,11 +218,22 @@ class HeterogeneousNextTokenChooser:
         self.seeds = seeds
         self.do_sample = do_sample
 
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        last_token_scores = self.warpers(input_ids, scores)
-        next_ids = self.choice(last_token_scores)
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
+
+        for warper in self.warpers:
+            scores = warper(input_ids, scores)
+
+        next_ids = self.choice(scores)
         next_logprobs = torch.gather(
-            torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
+            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
         ).view(-1)
 
         return next_ids, next_logprobs
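Note: the last hunk replaces the single LogitsProcessorList call with an explicit order: optional watermark processor, optional repetition-penalty processor, each warper in the plain Python list, then token choice and per-token log-probabilities computed from the warped scores. A simplified sketch of that flow; the Processor alias, the function name, and the greedy argmax are illustrative stand-ins, not repository code.

from typing import Callable, List, Optional

import torch

# A "processor" here is anything called as processor(input_ids, scores) -> scores.
Processor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def choose_next_tokens(
    input_ids: torch.Tensor,
    scores: torch.Tensor,
    warpers: List[Processor],
    watermark_processor: Optional[Processor] = None,
    repetition_processor: Optional[Processor] = None,
):
    # Optional processors run first, in a fixed order.
    if watermark_processor:
        scores = watermark_processor(input_ids, scores)
    if repetition_processor:
        scores = repetition_processor(input_ids, scores)

    # Then every warper in the plain Python list.
    for warper in warpers:
        scores = warper(input_ids, scores)

    # Greedy argmax stands in for the real mixed greedy/sampling choice.
    next_ids = scores.argmax(dim=-1)
    next_logprobs = torch.gather(
        torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
    ).view(-1)
    return next_ids, next_logprobs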