remove cuda graphs

This commit is contained in:
OlivierDehaene 2023-05-26 11:52:13 +02:00
parent 7e53903ca4
commit e8fd0e4841
2 changed files with 136 additions and 131 deletions

View File

@ -16,20 +16,6 @@ from transformers import (
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
class CUDAGraphWrapper:
def __init__(self):
self.cuda_graph = None
self.static_tensors = {}
# We want the LRU to be as big as possible as creating cuda graphs is expensive. However, each graph holds a tiny
# bit of GPU memory, so we still need to be careful
@lru_cache(512)
def get_cuda_graph_wrapper(warper_name, batch_size):
"""warper_name and batch_size are only used as keys"""
return CUDAGraphWrapper()
class StaticWarper:
def __init__(
self,
@ -102,20 +88,28 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
self.penalty = penalty
self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
score = torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
scores.scatter_(1, input_ids, score)
return scores
def filter(self, indices):
self.penalty = self.penalty[indices]
return self
self.penalty = [self.penalty[i] for i in indices]
if any([x != 1.0 for x in self.penalty]):
self.penalty_tensor = self.penalty_tensor[indices]
return self
return None
class HeterogeneousTemperatureLogitsWarper:
@ -132,17 +126,21 @@ class HeterogeneousTemperatureLogitsWarper:
def __init__(
self, temperature: List[float], dtype: torch.dtype, device: torch.device
):
self.temperature = torch.tensor(
self.temperature = temperature
self.temperature_tensor = torch.tensor(
temperature, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature)
scores.div_(self.temperature_tensor)
return scores
def filter(self, indices):
self.temperature = self.temperature[indices]
return self
self.temperature = [self.temperature[i] for i in indices]
if any([x != 1.0 for x in self.temperature]):
self.temperature_tensor = self.temperature_tensor[indices]
return self
return None
class HeterogeneousTopPLogitsWarper(LogitsWarper):
@ -169,62 +167,40 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_p = top_p
self.top_p_opposite = 1 - torch.tensor(
top_p, dtype=dtype, device=device
).unsqueeze(1)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
if self.cuda_graph_wrapper.cuda_graph is None:
self.cuda_graph_wrapper.static_tensors["scores"] = scores
self.cuda_graph_wrapper.static_tensors[
"top_p_opposite"
] = self.top_p_opposite
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
sorted_logits, sorted_indices = torch.sort(
local_scores, descending=False
)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
local_scores = local_scores.masked_fill_(
indices_to_remove, self.filter_value
)
self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
self.top_p_opposite
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
self.cuda_graph_wrapper.cuda_graph.replay()
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return self.cuda_graph_wrapper.static_tensors["warped_scores"]
return warped_scores
def filter(self, indices):
self.top_p_opposite = self.top_p_opposite[indices]
self.cuda_graph_wrapper = get_cuda_graph_wrapper(
"top_p_warper", len(self.top_p_opposite)
)
return self
self.top_p = [self.top_p[i] for i in indices]
if any([x < 1.0 for x in self.top_p]):
self.top_p_opposite = self.top_p_opposite[indices]
return self
return None
class HeterogeneousTopKLogitsWarper(LogitsWarper):
@ -264,7 +240,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
if any(disabled):
self.top_k_disabled_mask = torch.tensor(
disabled, dtype=torch.bool, device=device
)
).view(-1, 1)
else:
self.top_k_disabled_mask = None
@ -292,10 +268,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return scores
def filter(self, indices):
self.top_k_tensor = self.top_k_tensor[indices]
self.top_k = [self.top_k[i] for i in indices]
self.max_top_k = max(self.top_k)
return self
disabled = [x == 0 for x in self.top_k]
if not all(disabled):
self.top_k_tensor = self.top_k_tensor[indices]
self.max_top_k = max(self.top_k)
if self.top_k_disabled_mask is not None:
self.top_k_disabled_mask = (
self.top_k_disabled_mask[indices] if any(disabled) else None
)
return self
return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
@ -322,67 +308,70 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.mass = mass
self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
# 1 is a special value that disables typical_p warping for this member of the batch
disabled = [x == 1.0 for x in mass]
if any(disabled):
self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
else:
self.disabled_mask = None
self.filter_value = filter_value
self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
self.min_tokens_to_keep = min_tokens_to_keep
self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
if self.cuda_graph_wrapper.cuda_graph is None:
self.cuda_graph_wrapper.static_tensors["scores"] = scores
self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
# Remove tokens with cumulative mass above the threshold
last_ind = (probs < self.mass_tensor).sum(dim=1)
last_ind[last_ind < 0] = 0
# calculate entropy
normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
if self.disabled_mask is not None:
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(
shifted_scores, descending=False
)
sorted_logits = local_scores.gather(-1, sorted_indices)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
# Remove tokens with cumulative mass above the threshold
last_ind = (probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
local_scores = local_scores.masked_fill_(
indices_to_remove, self.filter_value
)
self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
self.cuda_graph_wrapper.cuda_graph.replay()
return self.cuda_graph_wrapper.static_tensors["warped_scores"]
return warped_scores
def filter(self, indices):
self.mass = self.mass[indices]
self.cuda_graph_wrapper = get_cuda_graph_wrapper(
"typical_p_warper", len(self.mass)
)
return self
self.mass = [self.mass[i] for i in indices]
disabled = [x == 1.0 for x in self.mass]
if not all(disabled):
self.mass_tensor = self.mass_tensor[indices]
if self.disabled_mask is not None:
self.disabled_mask = (
self.disabled_mask[indices] if any(disabled) else None
)
return self
return None
class HeterogeneousProcessorWrapper(LogitsProcessor):
@ -410,5 +399,7 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
if idx in self.processors:
new_processors[i] = self.processors[idx]
self.processors = new_processors
return self
if new_processors:
self.processors = new_processors
return self
return None

View File

@ -60,9 +60,9 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
if self.watermark_processor:
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor:
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.static_warper is None:
@ -209,19 +209,18 @@ class HeterogeneousNextTokenChooser:
self.warpers = warpers
num_do_sample = sum(do_sample)
if num_do_sample == 0:
self.choice = Greedy()
else:
if any(do_sample):
self.choice = HeterogeneousSampling(do_sample, seeds, device)
else:
self.choice = Greedy()
self.seeds = seeds
self.do_sample = do_sample
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor:
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor:
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
for warper in self.warpers:
@ -235,12 +234,27 @@ class HeterogeneousNextTokenChooser:
return next_ids, next_logprobs
def filter(self, indices):
if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices)
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
warper.filter(indices)
if isinstance(self.choice, HeterogeneousSampling):
self.choice.filter(indices)
filtered_warper = warper.filter(indices)
if filtered_warper is not None:
filtered_warpers.append(filtered_warper)
self.warpers = filtered_warpers
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
if any(self.do_sample):
self.choice.filter(indices)
else:
self.choice = Greedy()
return self
@classmethod