add shared pool

This commit is contained in:
OlivierDehaene 2023-05-25 18:26:41 +02:00
parent b973c101c5
commit 7e53903ca4

View File

@ -13,6 +13,22 @@ from transformers import (
TypicalLogitsWarper,
)
# Memory-pool handle passed as `pool=` to the `torch.cuda.graph(...)` captures
# in this module so that they draw from one common pool instead of each
# capture getting its own. Left as None when CUDA is not available.
if torch.cuda.is_available():
    mempool = torch.cuda.graph_pool_handle()
else:
    mempool = None
class CUDAGraphWrapper:
    """Mutable holder for one captured CUDA graph and its static tensors.

    Both attributes start out empty; the code that performs the capture is
    responsible for filling them in.
    """

    def __init__(self):
        # Assigned a `torch.cuda.CUDAGraph` by the warper that captures it;
        # None until a capture has happened.
        self.cuda_graph = None
        # Tensors the captured graph reads from / writes to, keyed by name
        # (e.g. "scores").
        self.static_tensors = {}
# We want the LRU to be as big as possible, as creating CUDA graphs is
# expensive. However, each graph holds a tiny bit of GPU memory, so we still
# need to be careful and keep the cache bounded.
@lru_cache(maxsize=512)
def get_cuda_graph_wrapper(warper_name, batch_size):
    """Return the cached `CUDAGraphWrapper` for (warper_name, batch_size).

    The arguments are used only as the cache key. The first call with a given
    pair creates an empty wrapper; later calls with the same pair return that
    same object for as long as the entry stays in the LRU cache.
    """
    return CUDAGraphWrapper()
class StaticWarper:
def __init__(
@ -44,7 +60,7 @@ class StaticWarper:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph):
with torch.cuda.graph(self.cuda_graph, pool=mempool):
local_scores = self.static_scores
for warper in self.warpers:
local_scores = warper(None, local_scores)
@ -129,18 +145,6 @@ class HeterogeneousTemperatureLogitsWarper:
return self
class CUDAGraphWrapper:
    """Mutable holder for one captured CUDA graph and its static tensors.

    Both attributes start out empty; the code that performs the capture is
    responsible for filling them in.
    """

    def __init__(self):
        # Assigned a `torch.cuda.CUDAGraph` by the warper that captures it;
        # None until a capture has happened.
        self.cuda_graph = None
        # Tensors the captured graph reads from / writes to, keyed by name
        # (e.g. "scores").
        self.static_tensors = {}
@lru_cache(maxsize=512)
def get_cuda_graph_wrapper(warper_name, batch_size):
    """Return the cached `CUDAGraphWrapper` for (warper_name, batch_size).

    The arguments are used only as the cache key. The first call with a given
    pair creates an empty wrapper; later calls with the same pair return that
    same object for as long as the entry stays in the LRU cache.
    """
    return CUDAGraphWrapper()
class HeterogeneousTopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
@ -181,7 +185,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
sorted_logits, sorted_indices = torch.sort(
@ -330,7 +334,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
# calculate entropy