Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
remove cuda graphs

This commit is contained in:
parent 7e53903ca4
commit e8fd0e4841
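For context on what the diff below deletes: the old warpers captured their sampling kernels with PyTorch CUDA graphs, recording the ops once against static tensors, then copying fresh scores into those tensors and replaying the graph on every call. A minimal sketch of that capture/replay pattern, assuming a CUDA device; shapes, names, and the softmax body are illustrative, not the code from this repository:

import torch

# Sketch of the CUDA-graph capture/replay pattern this commit removes.
# Guarded so the snippet stays importable on CPU-only machines.
if torch.cuda.is_available():
    mempool = torch.cuda.graph_pool_handle()             # shared graph memory pool, as in the old module
    static_scores = torch.randn(4, 1024, device="cuda")  # placeholder "static" input buffer

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        # Ops inside this block are recorded into the graph, not executed eagerly.
        warped_scores = static_scores.softmax(dim=-1)

    def warp(scores: torch.Tensor) -> torch.Tensor:
        static_scores.copy_(scores)  # feed new data into the captured input buffer
        graph.replay()               # re-run the recorded kernels
        return warped_scores         # the same output tensor is rewritten on each replay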
@@ -16,20 +16,6 @@ from transformers import (
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
 
-class CUDAGraphWrapper:
-    def __init__(self):
-        self.cuda_graph = None
-        self.static_tensors = {}
-
-
-# We want the LRU to be as big as possible as creating cuda graphs is expensive. However, each graph holds a tiny
-# bit of GPU memory, so we still need to be careful
-@lru_cache(512)
-def get_cuda_graph_wrapper(warper_name, batch_size):
-    """warper_name and batch_size are only used as keys"""
-    return CUDAGraphWrapper()
-
-
 class StaticWarper:
     def __init__(
         self,
@@ -102,20 +88,28 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
-        self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
+        self.penalty = penalty
+        self.penalty_tensor = torch.tensor(
+            penalty, dtype=dtype, device=device
+        ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score = torch.where(
+            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+        )
 
         scores.scatter_(1, input_ids, score)
         return scores
 
     def filter(self, indices):
-        self.penalty = self.penalty[indices]
-        return self
+        self.penalty = [self.penalty[i] for i in indices]
+        if any([x != 1.0 for x in self.penalty]):
+            self.penalty_tensor = self.penalty_tensor[indices]
+            return self
+        return None
 
 
 class HeterogeneousTemperatureLogitsWarper:
@@ -132,17 +126,21 @@ class HeterogeneousTemperatureLogitsWarper:
     def __init__(
         self, temperature: List[float], dtype: torch.dtype, device: torch.device
     ):
-        self.temperature = torch.tensor(
+        self.temperature = temperature
+        self.temperature_tensor = torch.tensor(
             temperature, dtype=dtype, device=device
         ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        scores.div_(self.temperature)
+        scores.div_(self.temperature_tensor)
         return scores
 
     def filter(self, indices):
-        self.temperature = self.temperature[indices]
-        return self
+        self.temperature = [self.temperature[i] for i in indices]
+        if any([x != 1.0 for x in self.temperature]):
+            self.temperature_tensor = self.temperature_tensor[indices]
+            return self
+        return None
 
 
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
@@ -169,28 +167,15 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         filter_value: float = -math.inf,
         min_tokens_to_keep: int = 1,
     ):
+        self.top_p = top_p
         self.top_p_opposite = 1 - torch.tensor(
             top_p, dtype=dtype, device=device
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        if self.cuda_graph_wrapper.cuda_graph is None:
-            self.cuda_graph_wrapper.static_tensors["scores"] = scores
-            self.cuda_graph_wrapper.static_tensors[
-                "top_p_opposite"
-            ] = self.top_p_opposite
-
-            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
-                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
-
-                sorted_logits, sorted_indices = torch.sort(
-                    local_scores, descending=False
-                )
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
         probs = sorted_logits.softmax(dim=-1)
         # This is way faster for some reason
         for i in range(probs.shape[0]):
@@ -206,25 +191,16 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         indices_to_remove = sorted_indices_to_remove.scatter(
             1, sorted_indices, sorted_indices_to_remove
         )
-                local_scores = local_scores.masked_fill_(
-                    indices_to_remove, self.filter_value
-                )
-                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
 
-        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
-        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
-            self.top_p_opposite
-        )
-        self.cuda_graph_wrapper.cuda_graph.replay()
-
-        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
+        return warped_scores
 
     def filter(self, indices):
+        self.top_p = [self.top_p[i] for i in indices]
+        if any([x < 1.0 for x in self.top_p]):
             self.top_p_opposite = self.top_p_opposite[indices]
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
-            "top_p_warper", len(self.top_p_opposite)
-        )
-        return self
+            return self
+        return None
 
 
 class HeterogeneousTopKLogitsWarper(LogitsWarper):
@@ -264,7 +240,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         if any(disabled):
             self.top_k_disabled_mask = torch.tensor(
                 disabled, dtype=torch.bool, device=device
-            )
+            ).view(-1, 1)
         else:
             self.top_k_disabled_mask = None
 
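The `.view(-1, 1)` reshape matters because `scores` has shape `(batch_size, vocab_size)`: a flat `(batch_size,)` boolean mask cannot broadcast against the vocabulary dimension, while a `(batch_size, 1)` mask disables whole rows. A small illustration with made-up values:

import torch

scores = torch.randn(3, 5)                    # (batch_size, vocab_size)
disabled = torch.tensor([True, False, True])  # e.g. top_k == 0 for rows 0 and 2

# disabled.view(-1, 1) has shape (3, 1) and broadcasts across the vocab dim,
# so masked_fill touches entire rows; a flat (3,) mask would fail to broadcast.
print(scores.masked_fill(disabled.view(-1, 1), 0.0))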
@@ -292,10 +268,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return scores
 
     def filter(self, indices):
-        self.top_k_tensor = self.top_k_tensor[indices]
         self.top_k = [self.top_k[i] for i in indices]
-        self.max_top_k = max(self.top_k)
-        return self
+        disabled = [x == 0 for x in self.top_k]
+
+        if not all(disabled):
+            self.top_k_tensor = self.top_k_tensor[indices]
+            self.max_top_k = max(self.top_k)
+
+            if self.top_k_disabled_mask is not None:
+                self.top_k_disabled_mask = (
+                    self.top_k_disabled_mask[indices] if any(disabled) else None
+                )
+
+            return self
+        return None
 
 
 class HeterogeneousTypicalLogitsWarper(LogitsWarper):
@@ -322,40 +308,42 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         filter_value: float = -math.inf,
         min_tokens_to_keep: int = 1,
     ):
+        self.mass = mass
+        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
+
+        # 1 is a special value that disables typical_p warping for this member of the batch
+        disabled = [x == 1.0 for x in mass]
+
+        if any(disabled):
+            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
+        else:
+            self.disabled_mask = None
+
         self.filter_value = filter_value
-        self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        if self.cuda_graph_wrapper.cuda_graph is None:
-            self.cuda_graph_wrapper.static_tensors["scores"] = scores
-            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
-
-            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
-                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
-
         # calculate entropy
-        normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
         p = torch.exp(normalized)
         ent = -(normalized * p).nansum(-1, keepdim=True)
 
         # shift and sort
         shifted_scores = torch.abs((-normalized) - ent)
-        sorted_scores, sorted_indices = torch.sort(
-            shifted_scores, descending=False
-        )
-        sorted_logits = local_scores.gather(-1, sorted_indices)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
         probs = sorted_logits.softmax(dim=-1)
         # This is way faster for some reason
         for i in range(probs.shape[0]):
             probs[i] = probs[i].cumsum(dim=-1)
 
         # Remove tokens with cumulative mass above the threshold
-        last_ind = (probs < self.mass).sum(dim=1)
+        last_ind = (probs < self.mass_tensor).sum(dim=1)
         last_ind[last_ind < 0] = 0
 
+        if self.disabled_mask is not None:
+            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
+
         sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
             1, last_ind.view(-1, 1)
         )
@@ -366,23 +354,24 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
             1, sorted_indices, sorted_indices_to_remove
         )
 
-                local_scores = local_scores.masked_fill_(
-                    indices_to_remove, self.filter_value
-                )
-                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
 
-        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
-        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
-        self.cuda_graph_wrapper.cuda_graph.replay()
-
-        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
+        return warped_scores
 
     def filter(self, indices):
-        self.mass = self.mass[indices]
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
-            "typical_p_warper", len(self.mass)
-        )
-        return self
+        self.mass = [self.mass[i] for i in indices]
+        disabled = [x == 1.0 for x in self.mass]
+
+        if not all(disabled):
+            self.mass_tensor = self.mass_tensor[indices]
+
+            if self.disabled_mask is not None:
+                self.disabled_mask = (
+                    self.disabled_mask[indices] if any(disabled) else None
+                )
+
+            return self
+        return None
 
 
 class HeterogeneousProcessorWrapper(LogitsProcessor):
@@ -410,5 +399,7 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
             if idx in self.processors:
                 new_processors[i] = self.processors[idx]
 
+        if new_processors:
             self.processors = new_processors
-        return self
+            return self
+        return None
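With the graphs gone, the warpers above settle on a common `filter` contract: after narrowing to the surviving batch indices, a warper returns `self` if any remaining row still needs warping and `None` once it has become a no-op (all penalties or temperatures at 1.0, all `top_p` at 1.0, all `top_k` at 0, all typical-p masses at 1.0). The `HeterogeneousNextTokenChooser.filter` hunk further down consumes that contract; a minimal standalone sketch of the caller side, with illustrative names:

from typing import List, Optional, Protocol


class Warper(Protocol):
    def filter(self, indices: List[int]) -> Optional["Warper"]:
        """Narrow to the surviving batch rows; return None if now a no-op."""


def prune_warpers(warpers: List["Warper"], indices: List[int]) -> List["Warper"]:
    # Keep only the warpers that still do work for the remaining requests.
    kept = []
    for warper in warpers:
        filtered = warper.filter(indices)
        if filtered is not None:
            kept.append(filtered)
    return kept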
@@ -60,9 +60,9 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
-        if self.watermark_processor:
+        if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor:
+        if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
 
         if self.static_warper is None:
@@ -209,19 +209,18 @@ class HeterogeneousNextTokenChooser:
 
         self.warpers = warpers
 
-        num_do_sample = sum(do_sample)
-        if num_do_sample == 0:
-            self.choice = Greedy()
-        else:
+        if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
+        else:
+            self.choice = Greedy()
 
         self.seeds = seeds
         self.do_sample = do_sample
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor:
+        if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor:
+        if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
 
         for warper in self.warpers:
@@ -235,12 +234,27 @@ class HeterogeneousNextTokenChooser:
         return next_ids, next_logprobs
 
     def filter(self, indices):
+        if self.watermark_processor is not None:
+            self.watermark_processor = self.watermark_processor.filter(indices)
+
+        if self.repetition_processor is not None:
+            self.repetition_processor = self.repetition_processor.filter(indices)
+
+        filtered_warpers = []
         for warper in self.warpers:
-            warper.filter(indices)
-        if isinstance(self.choice, HeterogeneousSampling):
-            self.choice.filter(indices)
+            filtered_warper = warper.filter(indices)
+            if filtered_warper is not None:
+                filtered_warpers.append(filtered_warper)
+        self.warpers = filtered_warpers
+
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
+
         return self
 
     @classmethod