mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
parent
ece7ffa40a
commit
53aa9194c8
@ -237,20 +237,12 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == "t5":
|
elif model_type == "t5":
|
||||||
if sharded:
|
return T5Sharded(
|
||||||
return T5Sharded(
|
model_id,
|
||||||
model_id,
|
revision,
|
||||||
revision,
|
quantize=quantize,
|
||||||
quantize=quantize,
|
trust_remote_code=trust_remote_code,
|
||||||
trust_remote_code=trust_remote_code,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
return Seq2SeqLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if sharded:
|
if sharded:
|
||||||
raise ValueError("sharded is not supported for AutoModel")
|
raise ValueError("sharded is not supported for AutoModel")
|
||||||
|
@ -42,25 +42,31 @@ class StaticWarper:
|
|||||||
self.static_next_logprob = None
|
self.static_next_logprob = None
|
||||||
|
|
||||||
def __call__(self, scores):
|
def __call__(self, scores):
|
||||||
if self.cuda_graph is None:
|
if torch.cuda.is_available():
|
||||||
self.static_scores = scores
|
if self.cuda_graph is None:
|
||||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
self.static_scores = scores
|
||||||
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
||||||
local_scores = self.static_scores
|
local_scores = self.static_scores
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
local_scores = warper(None, local_scores)
|
local_scores = warper(None, local_scores)
|
||||||
|
|
||||||
self.static_warped_scores = local_scores
|
self.static_warped_scores = local_scores
|
||||||
# Compute logprobs
|
# Compute logprobs
|
||||||
self.static_next_logprob = torch.log_softmax(
|
self.static_next_logprob = torch.log_softmax(
|
||||||
self.static_warped_scores, -1
|
self.static_warped_scores, -1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.static_scores.copy_(scores)
|
self.static_scores.copy_(scores)
|
||||||
self.cuda_graph.replay()
|
self.cuda_graph.replay()
|
||||||
|
|
||||||
return self.static_warped_scores, self.static_next_logprob
|
return self.static_warped_scores, self.static_next_logprob
|
||||||
|
|
||||||
|
# CPU branch
|
||||||
|
for warper in self.warpers:
|
||||||
|
scores = warper(None, scores)
|
||||||
|
return scores, torch.log_softmax(scores, -1)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(10)
|
@lru_cache(10)
|
||||||
|
Loading…
Reference in New Issue
Block a user