From e38cda5b9b89d5ed1ef143116a820056620e540e Mon Sep 17 00:00:00 2001
From: michaelfeil
Date: Tue, 25 Jul 2023 08:32:55 +0200
Subject: [PATCH] apply suggested changes

---
 .../models/ct2_causal_lm.py | 39 +++++++++++--------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py
index ac6dcd85..c6e8a161 100644
--- a/server/text_generation_server/models/ct2_causal_lm.py
+++ b/server/text_generation_server/models/ct2_causal_lm.py
@@ -72,12 +72,19 @@ class CT2CausalLM(Model):
         )
 
         # Start CT2
-        ct2_generator_kwargs = {"inter_threads": 1}
+        ct2_generator_kwargs = {
+            "inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1))
+        }
         if torch.cuda.is_available():
             self.ct2_device = "cuda"
+            ct2_generator_kwargs["intra_threads"] = int(
+                os.environ.get("TGI_CT2_INTRA_THREADS", 1)
+            )
         else:
             self.ct2_device = "cpu"
-            ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
+            ct2_generator_kwargs["intra_threads"] = int(
+                os.environ.get("TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2)
+            )
 
         if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
@@ -86,10 +93,10 @@ class CT2CausalLM(Model):
         elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
             # float16 is not available on CPU
             # and int16 has no stable implementation
-            ct2_compute_type = "float32" 
+            ct2_compute_type = "float32"
         else:
-            # default, int8 quantization. 
-            
+            # default, int8 quantization.
+
             if "cuda" in self.ct2_device:
                 # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
@@ -108,8 +115,8 @@ class CT2CausalLM(Model):
             converter = ctranslate2.converters.TransformersConverter(
                 model_id,
                 activation_scales=None,
-                load_as_float16=True,
-                revision=None,
+                load_as_float16=ct2_compute_type != "bfloat16",
+                revision=revision,
                 low_cpu_mem_usage=True,
                 trust_remote_code=trust_remote_code,
             )
@@ -125,13 +132,15 @@ class CT2CausalLM(Model):
             )
         if not os.path.exists(out_dir / "model.bin"):
             raise ValueError(
-                f"no ctranslate2 for {model_id} found after conversion in {out_dir}"
+                f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
             )
 
         # Start CT2
         self.ct2_model = ctranslate2.Generator(
-            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
-            **ct2_generator_kwargs
+            str(out_dir),
+            device=self.ct2_device,
+            compute_type=ct2_compute_type,
+            **ct2_generator_kwargs,
         )
 
         class DummyModel(torch.nn.Module):
@@ -216,19 +225,17 @@ class CT2CausalLM(Model):
         )
         # lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
         lengths = np.array(input_lengths, dtype=np.int32)
-        
+
         if self.ct2_device == "cuda":
-            lengths = torch.from_numpy(lengths).to(
-                self.ct2_device
-            )
+            lengths = torch.from_numpy(lengths).to(self.ct2_device)
         elif self.ct2_device == "cpu":
             ids_input = ids_input.numpy()
-        
+
         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)
         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)
-        
+
         # continue with logits as torch tensor
         if self.ct2_device == "cpu":
             # logits is a float32 torch cpu tensor
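
For context, a minimal standalone sketch of how the new TGI_CT2_INTER_THREADS / TGI_CT2_INTRA_THREADS overrides feed into a ctranslate2.Generator. The model path, device, and compute type below are placeholders, not values taken from this patch:

import multiprocessing
import os

import ctranslate2

# Resolve thread counts the way the patched __init__ does: the env vars
# win; otherwise fall back to 1 inter-thread and half the cores on CPU.
inter_threads = int(os.environ.get("TGI_CT2_INTER_THREADS", 1))
intra_threads = int(
    os.environ.get("TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2)
)

# "path/to/ct2_model" is a placeholder for a converted model directory.
generator = ctranslate2.Generator(
    "path/to/ct2_model",
    device="cpu",
    compute_type="int8",
    inter_threads=inter_threads,
    intra_threads=intra_threads,
)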
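
The converter hunk switches load_as_float16 to depend on the resolved compute type and forwards the caller's revision instead of pinning None. A rough, self-contained conversion sketch; "gpt2" and the output directory are assumed stand-ins for model_id and out_dir:

import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "gpt2",  # placeholder model_id
    activation_scales=None,
    load_as_float16=True,  # mirrors ct2_compute_type != "bfloat16"
    revision=None,  # the patch passes the caller's revision here
    low_cpu_mem_usage=True,
    trust_remote_code=False,
)
# Writes model.bin (plus config) into the output directory.
out_dir = converter.convert("ct2-gpt2", force=True)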
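
The last hunk tidies the forward pass: padded token ids and their true lengths are wrapped in ctranslate2.StorageView objects and handed to Generator.forward_batch. A small CPU-only usage sketch with made-up token ids, reusing the generator from the first sketch:

import numpy as np

import ctranslate2

ids = np.array([[5, 9, 3, 0], [7, 2, 0, 0]], dtype=np.int32)  # padded batch
lengths = np.array([3, 2], dtype=np.int32)  # unpadded lengths per row

ids_sv = ctranslate2.StorageView.from_array(ids)
lengths_sv = ctranslate2.StorageView.from_array(lengths)

# forward_batch returns the logits as a StorageView; on CPU it exposes
# the numpy array interface, so it can be viewed without a copy.
logits = generator.forward_batch(ids_sv, lengths_sv)
logits_np = np.asarray(logits)  # float32, shape (batch, seq, vocab)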