From 7e53903ca482f5ad0ab0116582041eabc7c0243a Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 25 May 2023 18:26:41 +0200
Subject: [PATCH] add shared pool

---
 .../utils/logits_process.py | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index bf61322b..08e1c443 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -13,6 +13,22 @@ from transformers import (
     TypicalLogitsWarper,
 )
 
+mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+
+
+class CUDAGraphWrapper:
+    def __init__(self):
+        self.cuda_graph = None
+        self.static_tensors = {}
+
+
+# We want the LRU cache to be as large as possible, as creating CUDA graphs is expensive. However, each graph
+# holds a small amount of GPU memory, so we still need to be careful.
+@lru_cache(512)
+def get_cuda_graph_wrapper(warper_name, batch_size):
+    """warper_name and batch_size are only used as keys"""
+    return CUDAGraphWrapper()
+
 
 class StaticWarper:
     def __init__(
@@ -44,7 +60,7 @@ class StaticWarper:
             self.static_scores = scores
             self.cuda_graph = torch.cuda.CUDAGraph()
 
-            with torch.cuda.graph(self.cuda_graph):
+            with torch.cuda.graph(self.cuda_graph, pool=mempool):
                 local_scores = self.static_scores
                 for warper in self.warpers:
                     local_scores = warper(None, local_scores)
@@ -129,18 +145,6 @@ class HeterogeneousTemperatureLogitsWarper:
         return self
 
 
-class CUDAGraphWrapper:
-    def __init__(self):
-        self.cuda_graph = None
-        self.static_tensors = {}
-
-
-@lru_cache(512)
-def get_cuda_graph_wrapper(warper_name, batch_size):
-    """warper_name and batch_size are only used as keys"""
-    return CUDAGraphWrapper()
-
-
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
@@ -181,7 +185,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
 
             self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
 
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
                 local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
 
                 sorted_logits, sorted_indices = torch.sort(
@@ -330,7 +334,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
 
             self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
 
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
                 local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
 
                 # calculate entropy
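
Note for reviewers: below is a minimal, self-contained sketch (not part of the patch; the tensor shapes and the softmax/log_softmax ops are illustrative assumptions) of the shared-pool capture pattern this change applies in StaticWarper and the Heterogeneous warpers. A single torch.cuda.graph_pool_handle() is passed as pool= to every capture, so all graphs draw from one allocator pool instead of each reserving its own; per the PyTorch CUDA graphs documentation, graphs that share a pool should be replayed in the order they were captured.

    import torch

    # One pool shared by every capture, mirroring the module-level `mempool`
    # in the patch. Assumes a CUDA device is present.
    mempool = torch.cuda.graph_pool_handle()

    # Static input tensor; replays read whatever was last copied into it.
    static_scores = torch.randn(4, 32000, device="cuda")

    # Warm up the captured ops on a side stream before capture, as the
    # PyTorch CUDA graphs docs recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        static_scores.softmax(dim=-1)
        static_scores.log_softmax(dim=-1)
    torch.cuda.current_stream().wait_stream(side)

    graph_a = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph_a, pool=mempool):
        out_a = static_scores.softmax(dim=-1)

    graph_b = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph_b, pool=mempool):  # reuses graph_a's pool
        out_b = static_scores.log_softmax(dim=-1)

    # Replay with fresh data: copy into the static input, then replay the
    # graphs in the order they were captured.
    static_scores.copy_(torch.randn(4, 32000, device="cuda"))
    graph_a.replay()  # result is written into out_a
    graph_b.replay()  # result is written into out_b

The motivation for the shared pool is that each new capture can draw on memory already reserved in the pool rather than holding a private reservation, which is what makes the @lru_cache(512) bound on live (warper_name, batch_size) wrappers affordable: evicting a wrapper costs a re-capture, not a stranded pool.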