mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
adding suggested changes for os.environ vars and reporting correct torch.dtype on API
This commit is contained in:
parent
e38cda5b9b
commit
79b4620107
@ -73,18 +73,18 @@ class CT2CausalLM(Model):
|
||||
|
||||
# Start CT2
|
||||
ct2_generator_kwargs = {
|
||||
"inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
|
||||
"inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1))
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
self.ct2_device = "cuda"
|
||||
ct2_generator_kwargs["intra_threads"] = os.environ.get(
|
||||
ct2_generator_kwargs["intra_threads"] = int(os.environ.get(
|
||||
"TGI_CT2_INTRA_THREADS", 1
|
||||
)
|
||||
))
|
||||
else:
|
||||
self.ct2_device = "cpu"
|
||||
ct2_generator_kwargs["intra_threads"] = os.environ.get(
|
||||
ct2_generator_kwargs["intra_threads"] = int(os.environ.get(
|
||||
"TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
|
||||
)
|
||||
))
|
||||
|
||||
if dtype == torch.float16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "float16"
|
||||
@ -166,7 +166,7 @@ class CT2CausalLM(Model):
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int8 if "int8" in ct2_compute_type else torch.float16,
|
||||
device=torch.device(self.ct2_device),
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user