diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9540d99e6..3fdc23b2e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -237,20 +237,12 @@ def get_model(
         )
 
     elif model_type == "t5":
-        if sharded:
-            return T5Sharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            return Seq2SeqLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
+        return T5Sharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
+        )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index faa945164..0cbbf8b0f 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -42,25 +42,31 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                local_scores = self.static_scores
-                for warper in self.warpers:
-                    local_scores = warper(None, local_scores)
-
-            self.static_warped_scores = local_scores
-            # Compute logprobs
-            self.static_next_logprob = torch.log_softmax(
-                self.static_warped_scores, -1
-            )
-
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
-
-        return self.static_warped_scores, self.static_next_logprob
+        if torch.cuda.is_available():
+            if self.cuda_graph is None:
+                self.static_scores = scores
+                self.cuda_graph = torch.cuda.CUDAGraph()
+
+                with torch.cuda.graph(self.cuda_graph, pool=mempool):
+                    local_scores = self.static_scores
+                    for warper in self.warpers:
+                        local_scores = warper(None, local_scores)
+
+                self.static_warped_scores = local_scores
+                # Compute logprobs
+                self.static_next_logprob = torch.log_softmax(
+                    self.static_warped_scores, -1
+                )
+
+            self.static_scores.copy_(scores)
+            self.cuda_graph.replay()
+
+            return self.static_warped_scores, self.static_next_logprob
+
+        # CPU branch
+        for warper in self.warpers:
+            scores = warper(None, scores)
+        return scores, torch.log_softmax(scores, -1)
 
 
 @lru_cache(10)
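
For reference, a minimal sketch of what the new CPU branch computes, under stated assumptions: the warper chain here is a stand-in built from stock transformers processors (TemperatureLogitsWarper, TopKLogitsWarper), and the logits tensor and vocab size are placeholders. It mirrors the eager loop added above rather than exercising StaticWarper itself:

import torch
from transformers import TemperatureLogitsWarper, TopKLogitsWarper

# Placeholder warper chain; in the server, StaticWarper builds its list
# from the request's sampling parameters.
warpers = [TemperatureLogitsWarper(0.8), TopKLogitsWarper(top_k=50)]

# Fake next-token logits on CPU (batch of 1, hypothetical vocab size).
scores = torch.randn(1, 32000)

# CPU branch: apply each warper eagerly, no CUDA graph capture/replay.
# input_ids is passed as None, matching the call in StaticWarper.
for warper in warpers:
    scores = warper(None, scores)

next_logprob = torch.log_softmax(scores, -1)
print(scores.shape, next_logprob.shape)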