fix(server): fix warpers on CPU (#472)

Closes #471
This commit is contained in:
OlivierDehaene 2023-06-20 11:06:10 +02:00 committed by GitHub
parent ece7ffa40a
commit 53aa9194c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 29 deletions

View File

@@ -237,20 +237,12 @@ def get_model(
) )
elif model_type == "t5": elif model_type == "t5":
if sharded:
return T5Sharded( return T5Sharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else:
return Seq2SeqLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if sharded: if sharded:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")

View File

@@ -42,6 +42,7 @@ class StaticWarper:
self.static_next_logprob = None self.static_next_logprob = None
def __call__(self, scores): def __call__(self, scores):
if torch.cuda.is_available():
if self.cuda_graph is None: if self.cuda_graph is None:
self.static_scores = scores self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph() self.cuda_graph = torch.cuda.CUDAGraph()
@@ -62,6 +63,11 @@ class StaticWarper:
return self.static_warped_scores, self.static_next_logprob return self.static_warped_scores, self.static_next_logprob
# CPU branch
for warper in self.warpers:
scores = warper(None, scores)
return scores, torch.log_softmax(scores, -1)
@lru_cache(10) @lru_cache(10)
def static_warper( def static_warper(