mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add shared pool
This commit is contained in:
parent
b973c101c5
commit
7e53903ca4
@ -13,6 +13,22 @@ from transformers import (
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
|
||||
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
def __init__(self):
|
||||
self.cuda_graph = None
|
||||
self.static_tensors = {}
|
||||
|
||||
|
||||
# We want the LRU to be as big as possible as creating cuda graphs is expensive. However, each graph holds a tiny
|
||||
# bit of GPU memory, so we still need to be careful
|
||||
@lru_cache(512)
|
||||
def get_cuda_graph_wrapper(warper_name, batch_size):
|
||||
"""warper_name and batch_size are only used as keys"""
|
||||
return CUDAGraphWrapper()
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
@ -44,7 +60,7 @@ class StaticWarper:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
||||
local_scores = self.static_scores
|
||||
for warper in self.warpers:
|
||||
local_scores = warper(None, local_scores)
|
||||
@ -129,18 +145,6 @@ class HeterogeneousTemperatureLogitsWarper:
|
||||
return self
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
def __init__(self):
|
||||
self.cuda_graph = None
|
||||
self.static_tensors = {}
|
||||
|
||||
|
||||
@lru_cache(512)
|
||||
def get_cuda_graph_wrapper(warper_name, batch_size):
|
||||
"""warper_name and batch_size are only used as keys"""
|
||||
return CUDAGraphWrapper()
|
||||
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
@ -181,7 +185,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
|
||||
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
|
||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
|
||||
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(
|
||||
@ -330,7 +334,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
|
||||
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
|
||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
|
||||
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
||||
|
||||
# calculate entropy
|
||||
|
Loading…
Reference in New Issue
Block a user