mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add shared pool
This commit is contained in:
parent
b973c101c5
commit
7e53903ca4
@ -13,6 +13,22 @@ from transformers import (
|
|||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Module-level CUDA graph memory-pool handle, or None when CUDA is not available.
# Every `torch.cuda.graph(...)` capture in this module passes it as `pool=mempool`,
# so all captured graphs are given the same pool handle.
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAGraphWrapper:
    """Mutable holder for a CUDA graph and the tensors used by its capture.

    A new instance is empty. The heterogeneous warpers in this module fill it
    in later: they assign a ``torch.cuda.CUDAGraph`` to ``cuda_graph`` and
    read the capture input from ``static_tensors["scores"]``.
    """

    def __init__(self) -> None:
        # No graph has been captured yet; None marks the wrapper as empty.
        self.cuda_graph = None
        # Tensors used by the capture, keyed by name (e.g. "scores").
        self.static_tensors = {}
|
||||||
|
|
||||||
|
|
||||||
|
# We want the LRU to be as big as possible as creating cuda graphs is expensive. However, each graph holds a tiny
# bit of GPU memory, so we still need to be careful
@lru_cache(maxsize=512)
def get_cuda_graph_wrapper(warper_name, batch_size):
    """Return the cached `CUDAGraphWrapper` for a (warper_name, batch_size) pair.

    warper_name and batch_size are only used as keys: the body ignores them and
    `lru_cache` hands back the same wrapper object for repeated calls with the
    same arguments, until the entry is evicted from the 512-entry cache.

    Args:
        warper_name: hashable identifier of the warper requesting the wrapper.
        batch_size: hashable batch size the graph is associated with.

    Returns:
        A `CUDAGraphWrapper`; empty the first time a given key is requested.
    """
    return CUDAGraphWrapper()
|
||||||
|
|
||||||
|
|
||||||
class StaticWarper:
|
class StaticWarper:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -44,7 +60,7 @@ class StaticWarper:
|
|||||||
self.static_scores = scores
|
self.static_scores = scores
|
||||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
with torch.cuda.graph(self.cuda_graph):
|
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
||||||
local_scores = self.static_scores
|
local_scores = self.static_scores
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
local_scores = warper(None, local_scores)
|
local_scores = warper(None, local_scores)
|
||||||
@ -129,18 +145,6 @@ class HeterogeneousTemperatureLogitsWarper:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
def __init__(self):
|
|
||||||
self.cuda_graph = None
|
|
||||||
self.static_tensors = {}
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(512)
|
|
||||||
def get_cuda_graph_wrapper(warper_name, batch_size):
|
|
||||||
"""warper_name and batch_size are only used as keys"""
|
|
||||||
return CUDAGraphWrapper()
|
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||||
@ -181,7 +185,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|||||||
|
|
||||||
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
|
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
|
||||||
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
||||||
|
|
||||||
sorted_logits, sorted_indices = torch.sort(
|
sorted_logits, sorted_indices = torch.sort(
|
||||||
@ -330,7 +334,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
|||||||
|
|
||||||
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
|
with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
|
||||||
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
|
||||||
|
|
||||||
# calculate entropy
|
# calculate entropy
|
||||||
|
Loading…
Reference in New Issue
Block a user