diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py
index 8314830a..ac6dcd85 100644
--- a/server/text_generation_server/models/ct2_causal_lm.py
+++ b/server/text_generation_server/models/ct2_causal_lm.py
@@ -18,6 +18,7 @@
 import torch
 import numpy as np
 import os
+import multiprocessing
 from pathlib import Path
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -71,10 +72,12 @@ class CT2CausalLM(Model):
         )
 
         # Start CT2
+        ct2_generator_kwargs = {"inter_threads": 1}
         if torch.cuda.is_available():
             self.ct2_device = "cuda"
         else:
             self.ct2_device = "cpu"
+            ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
 
         if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
@@ -127,7 +130,8 @@ class CT2CausalLM(Model):
 
         # Start CT2
         self.ct2_model = ctranslate2.Generator(
-            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type
+            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
+            **ct2_generator_kwargs
         )
 
     class DummyModel(torch.nn.Module):
@@ -210,7 +214,7 @@ class CT2CausalLM(Model):
             .flatten(1)
             .to(torch.int32)
        )
-        # lengths of the padded ids_input, i.e. how often 1234567 is used.
+        # lengths of the padded ids_input, i.e. how many tokens are not pad=1234567.
         lengths = np.array(input_lengths, dtype=np.int32)
 
         if self.ct2_device == "cuda":
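
For reference, a minimal standalone sketch of the threading setup this patch applies, assuming a CTranslate2 model already converted to a local directory (the "./ct2_model" path below is hypothetical). In CTranslate2, inter_threads bounds how many batches can be processed in parallel and intra_threads sets the number of OpenMP threads used per batch; intra_threads only has an effect on CPU, which is why the patch sets it in the CPU branch only.

import multiprocessing

import ctranslate2
import torch

# Same logic as the patch: a single model worker (inter_threads=1); on CPU,
# give CTranslate2 half of the logical cores for intra-op parallelism.
ct2_generator_kwargs = {"inter_threads": 1}
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2

generator = ctranslate2.Generator(
    "./ct2_model",  # hypothetical path to a converted model directory
    device=device,
    compute_type="float16" if device == "cuda" else "default",
    **ct2_generator_kwargs,
)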