cpu speedup kwargs

This commit is contained in:
michaelfeil 2023-07-24 23:13:32 +02:00
parent 336ea37637
commit be6c9acf46

View File

@ -18,6 +18,7 @@
import torch import torch
import numpy as np import numpy as np
import os import os
import multiprocessing
from pathlib import Path from pathlib import Path
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@ -71,10 +72,12 @@ class CT2CausalLM(Model):
) )
# Start CT2 # Start CT2
ct2_generator_kwargs = {"inter_threads": 1}
if torch.cuda.is_available(): if torch.cuda.is_available():
self.ct2_device = "cuda" self.ct2_device = "cuda"
else: else:
self.ct2_device = "cpu" self.ct2_device = "cpu"
ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
if dtype == torch.float16 and self.ct2_device == "cuda": if dtype == torch.float16 and self.ct2_device == "cuda":
ct2_compute_type = "float16" ct2_compute_type = "float16"
@ -127,7 +130,8 @@ class CT2CausalLM(Model):
# Start CT2 # Start CT2
self.ct2_model = ctranslate2.Generator( self.ct2_model = ctranslate2.Generator(
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
**ct2_generator_kwargs
) )
class DummyModel(torch.nn.Module): class DummyModel(torch.nn.Module):
@ -210,7 +214,7 @@ class CT2CausalLM(Model):
.flatten(1) .flatten(1)
.to(torch.int32) .to(torch.int32)
) )
# lengths of the padded ids_input, i.e. how often 1234567 is used. # lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
lengths = np.array(input_lengths, dtype=np.int32) lengths = np.array(input_lengths, dtype=np.int32)
if self.ct2_device == "cuda": if self.ct2_device == "cuda":