apply suggested changes

This commit is contained in:
michaelfeil 2023-07-25 08:32:55 +02:00
parent be6c9acf46
commit e38cda5b9b

View File

@ -72,12 +72,19 @@ class CT2CausalLM(Model):
)
# Start CT2
ct2_generator_kwargs = {"inter_threads": 1}
ct2_generator_kwargs = {
"inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
}
if torch.cuda.is_available():
self.ct2_device = "cuda"
ct2_generator_kwargs["intra_threads"] = os.environ.get(
"TGI_CT2_INTRA_THREADS", 1
)
else:
self.ct2_device = "cpu"
ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
ct2_generator_kwargs["intra_threads"] = os.environ.get(
"TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
)
if dtype == torch.float16 and self.ct2_device == "cuda":
ct2_compute_type = "float16"
@ -108,8 +115,8 @@ class CT2CausalLM(Model):
converter = ctranslate2.converters.TransformersConverter(
model_id,
activation_scales=None,
load_as_float16=True,
revision=None,
load_as_float16=ct2_compute_type != "bfloat16",
revision=revision,
low_cpu_mem_usage=True,
trust_remote_code=trust_remote_code,
)
@ -125,13 +132,15 @@ class CT2CausalLM(Model):
)
if not os.path.exists(out_dir / "model.bin"):
raise ValueError(
f"no ctranslate2 for {model_id} found after conversion in {out_dir}"
f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
)
# Start CT2
self.ct2_model = ctranslate2.Generator(
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
**ct2_generator_kwargs
str(out_dir),
device=self.ct2_device,
compute_type=ct2_compute_type,
**ct2_generator_kwargs,
)
class DummyModel(torch.nn.Module):
@ -218,9 +227,7 @@ class CT2CausalLM(Model):
lengths = np.array(input_lengths, dtype=np.int32)
if self.ct2_device == "cuda":
lengths = torch.from_numpy(lengths).to(
self.ct2_device
)
lengths = torch.from_numpy(lengths).to(self.ct2_device)
elif self.ct2_device == "cpu":
ids_input = ids_input.numpy()