Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

Commit: d3cb0d3b83 ("faster cumsum")
Parent: caa9608347
@@ -286,7 +286,9 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -289,7 +289,9 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -325,7 +325,9 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -129,6 +129,18 @@ class HeterogeneousTemperatureLogitsWarper:
         return self
 
 
+class CUDAGraphWrapper:
+    def __init__(self):
+        self.cuda_graph = None
+        self.static_tensors = {}
+
+
+@lru_cache(512)
+def get_cuda_graph_wrapper(warper_name, batch_size):
+    """warper_name and batch_size are only used as keys"""
+    return CUDAGraphWrapper()
+
+
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
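The helpers added above memoize one CUDAGraphWrapper per (warper name, batch size) key, so every warper built for the same batch size shares the same wrapper and, once captured, the same CUDA graph. A standalone sketch of that keying behaviour (the class and factory are copied from the hunk; the asserts are mine):

from functools import lru_cache


class CUDAGraphWrapper:
    def __init__(self):
        self.cuda_graph = None    # captured on first call, replayed afterwards
        self.static_tensors = {}  # fixed-address inputs/outputs for the graph


@lru_cache(512)
def get_cuda_graph_wrapper(warper_name, batch_size):
    """warper_name and batch_size are only used as keys"""
    return CUDAGraphWrapper()


# Same key -> same wrapper (and same captured graph); new batch size -> new wrapper.
assert get_cuda_graph_wrapper("top_p_warper", 4) is get_cuda_graph_wrapper("top_p_warper", 4)
assert get_cuda_graph_wrapper("top_p_warper", 4) is not get_cuda_graph_wrapper("top_p_warper", 8)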
@@ -158,13 +170,30 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors[
+                "top_p_opposite"
+            ] = self.top_p_opposite
+
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+
+                sorted_logits, sorted_indices = torch.sort(
+                    local_scores, descending=False
+                )
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
 
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+                # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+                sorted_indices_to_remove = probs <= self.top_p_opposite
+                if self.min_tokens_to_keep > 1:
+                    # Keep at least min_tokens_to_keep
+                    sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
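The per-row loop above is the change the commit title refers to: the single batched cumsum over the sorted probabilities is replaced by one cumsum per row, which the author found faster in this context ("for some reason"). A rough, hardware- and version-dependent way to probe the two variants on your own GPU; the harness below is illustrative and not from the repo, and timings outside a graph capture may well come out the other way:

import time

import torch


def bench(fn, x, iters=100):
    # Crude CUDA timing: warm up once, then synchronize around a timed loop.
    fn(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


if torch.cuda.is_available():
    probs = torch.rand(32, 50257, device="cuda").softmax(dim=-1)

    def batched(p):
        return p.cumsum(dim=-1)

    def per_row(p):
        out = p.clone()
        for i in range(out.shape[0]):
            out[i] = out[i].cumsum(dim=-1)
        return out

    # Both variants compute the same values.
    assert torch.allclose(batched(probs), per_row(probs), atol=1e-6)
    print("batched cumsum:", bench(batched, probs))
    print("per-row cumsum:", bench(per_row, probs))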
@@ -173,11 +202,24 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    1, sorted_indices, sorted_indices_to_remove
+                )
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
+            self.top_p_opposite
+        )
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.top_p_opposite = self.top_p_opposite[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "top_p_warper", len(self.top_p_opposite)
+        )
         return self
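Both warpers now follow the same capture-once, replay-forever pattern: the first call records the kernels into a torch.cuda.CUDAGraph against fixed "static" tensors, and every call afterwards copies fresh inputs into those tensors and replays the graph, skipping Python and dispatch overhead. A stripped-down standalone sketch of that flow with a toy kernel; the names and the softmax stand-in are mine, not the repo's, and production code usually runs warm-up iterations before capturing:

import torch

_static = {}
_graph = None


def warp(scores: torch.Tensor) -> torch.Tensor:
    """Toy version of the warper's capture-then-replay flow."""
    global _graph
    if _graph is None:
        # First call: remember the input storage and record the kernels.
        _static["scores"] = scores
        _graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(_graph):
            _static["warped"] = _static["scores"].softmax(dim=-1)
    # Every call: refresh the static input in place, replay, read the output.
    _static["scores"].copy_(scores)
    _graph.replay()
    return _static["warped"]


if torch.cuda.is_available():
    a = warp(torch.randn(4, 16, device="cuda"))
    b = warp(torch.randn(4, 16, device="cuda"))
    print(a.sum(dim=-1), b.sum(dim=-1))  # each row sums to ~1 on both calls

This is also why filter() re-fetches its wrapper: a captured graph bakes in tensor shapes, so when requests leave the batch the warper must switch to the wrapper cached for the new batch size.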
@@ -279,21 +321,36 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
+
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+
-        # calculate entropy
-        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
-        p = torch.exp(normalized)
-        ent = -(normalized * p).nansum(-1, keepdim=True)
+                # calculate entropy
+                normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
+                p = torch.exp(normalized)
+                ent = -(normalized * p).nansum(-1, keepdim=True)
 
-        # shift and sort
-        shifted_scores = torch.abs((-normalized) - ent)
-        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-        sorted_logits = scores.gather(-1, sorted_indices)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+                # shift and sort
+                shifted_scores = torch.abs((-normalized) - ent)
+                sorted_scores, sorted_indices = torch.sort(
+                    shifted_scores, descending=False
+                )
+                sorted_logits = local_scores.gather(-1, sorted_indices)
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
 
-        # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
+                # Remove tokens with cumulative mass above the threshold
+                last_ind = (probs < self.mass).sum(dim=1)
+                last_ind[last_ind < 0] = 0
+                sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+                    1, last_ind.view(-1, 1)
@@ -305,11 +362,22 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
-            1, sorted_indices, sorted_indices_to_remove
-        )
+                    1, sorted_indices, sorted_indices_to_remove
+                )
 
-        scores = scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.mass = self.mass[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "typical_p_warper", len(self.mass)
+        )
         return self
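A CPU-side sanity check that the rewritten typical-p filtering (per-row cumsum into probs) selects exactly the same tokens as the old single-cumsum path. The function mirrors the math in the two hunks above; the harness around it is illustrative:

import torch


def remove_mask(scores, mass, per_row_cumsum):
    # Mirrors the typical-p math from the diff, outside any CUDA graph.
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    probs = sorted_logits.softmax(dim=-1)
    if per_row_cumsum:  # new path ("faster cumsum")
        for i in range(probs.shape[0]):
            probs[i] = probs[i].cumsum(dim=-1)
    else:  # old path
        probs = probs.cumsum(dim=-1)
    last_ind = (probs < mass).sum(dim=1)
    last_ind[last_ind < 0] = 0
    remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    return remove.scatter(1, sorted_indices, remove)


scores = torch.randn(3, 17)
mass = torch.tensor([0.9, 0.9, 0.5]).unsqueeze(1)
assert torch.equal(remove_mask(scores, mass, True), remove_mask(scores, mass, False))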
@@ -4,7 +4,6 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
-    LogitsProcessorList,
 )
 from typing import List, Tuple, Optional
 
@@ -166,10 +165,9 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
     ):
-        warpers = LogitsProcessorList()
+        warpers = []
 
-        if any(watermark):
-            warpers.append(
-                HeterogeneousProcessorWrapper(
-                    {
-                        i: WatermarkLogitsProcessor(device=device)
+        self.watermark_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: WatermarkLogitsProcessor(device=device)
@@ -177,13 +175,16 @@ class HeterogeneousNextTokenChooser:
-                        if do_watermark
-                    }
-                )
-            )
+                    if do_watermark
+                }
+            )
+            if any(watermark)
+            else None
+        )
 
-        if any([x != 1.0 for x in repetition_penalty]):
-            warpers.append(
-                HeterogeneousRepetitionPenaltyLogitsProcessor(
-                    repetition_penalty, dtype, device
-                )
-            )
+        self.repetition_processor = (
+            HeterogeneousRepetitionPenaltyLogitsProcessor(
+                repetition_penalty, dtype, device
+            )
+            if any([x != 1.0 for x in repetition_penalty])
+            else None
+        )
 
         if any([x != 1.0 for x in temperature]):
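The constructor change swaps "append to a LogitsProcessorList when enabled" for optional attributes built with conditional expressions, which __call__ can then test directly. The shape of that pattern in isolation, with stand-in values rather than the real processor classes:

watermark = [False, True, False]
repetition_penalty = [1.0, 1.0, 1.3]

watermark_processor = (
    {i: f"watermark[{i}]" for i, do_watermark in enumerate(watermark) if do_watermark}
    if any(watermark)
    else None
)
repetition_processor = (
    ("repetition", repetition_penalty)
    if any(x != 1.0 for x in repetition_penalty)
    else None
)

print(watermark_processor)   # {1: 'watermark[1]'}
print(repetition_processor)  # ('repetition', [1.0, 1.0, 1.3])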
@@ -217,11 +218,22 @@ class HeterogeneousNextTokenChooser:
         self.seeds = seeds
         self.do_sample = do_sample
 
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        last_token_scores = self.warpers(input_ids, scores)
-        next_ids = self.choice(last_token_scores)
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
+
+        for warper in self.warpers:
+            scores = warper(input_ids, scores)
+
+        next_ids = self.choice(scores)
         next_logprobs = torch.gather(
-            torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
+            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
         ).view(-1)
 
         return next_ids, next_logprobs
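Putting the chooser changes together: __call__ now runs the optional processors first, then each warper in a plain Python list, then samples and gathers logprobs from the final scores. A minimal stand-in showing the same control flow; the greedy choice and lambda warper are illustrative, not the repo's classes:

import torch


class Chooser:
    # Mirrors the new HeterogeneousNextTokenChooser.__call__ flow.
    def __init__(self, watermark_processor=None, repetition_processor=None, warpers=()):
        self.watermark_processor = watermark_processor
        self.repetition_processor = repetition_processor
        self.warpers = list(warpers)

    def choice(self, scores):
        return scores.argmax(dim=-1)  # greedy stand-in for the real sampler

    def __call__(self, input_ids, scores):
        if self.watermark_processor:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor:
            scores = self.repetition_processor(input_ids, scores)
        for warper in self.warpers:
            scores = warper(input_ids, scores)
        next_ids = self.choice(scores)
        next_logprobs = torch.gather(
            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
        ).view(-1)
        return next_ids, next_logprobs


chooser = Chooser(warpers=[lambda ids, s: s / 0.7])  # temperature-style warper
ids, logprobs = chooser(None, torch.randn(2, 10))
print(ids, logprobs)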