adding suggested changes for os.environ vars and reporting correct torch.dtype on API

This commit is contained in:
michaelfeil 2023-07-25 09:05:08 +02:00
parent e38cda5b9b
commit 79b4620107

View File

@ -73,18 +73,18 @@ class CT2CausalLM(Model):
# Start CT2
ct2_generator_kwargs = {
"inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
"inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1))
}
if torch.cuda.is_available():
self.ct2_device = "cuda"
ct2_generator_kwargs["intra_threads"] = os.environ.get(
ct2_generator_kwargs["intra_threads"] = int(os.environ.get(
"TGI_CT2_INTRA_THREADS", 1
)
))
else:
self.ct2_device = "cpu"
ct2_generator_kwargs["intra_threads"] = os.environ.get(
ct2_generator_kwargs["intra_threads"] = int(os.environ.get(
"TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
)
))
if dtype == torch.float16 and self.ct2_device == "cuda":
ct2_compute_type = "float16"
@ -166,7 +166,7 @@ class CT2CausalLM(Model):
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=torch.int32,
dtype=torch.int8 if "int8" in ct2_compute_type else torch.float16,
device=torch.device(self.ct2_device),
)