apply suggested changes

michaelfeil 2023-07-25 08:32:55 +02:00
parent be6c9acf46
commit e38cda5b9b


@@ -72,12 +72,19 @@ class CT2CausalLM(Model):
         )
         # Start CT2
-        ct2_generator_kwargs = {"inter_threads": 1}
+        ct2_generator_kwargs = {
+            "inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
+        }
         if torch.cuda.is_available():
             self.ct2_device = "cuda"
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", 1
+            )
         else:
             self.ct2_device = "cpu"
-            ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
+            )

         if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
@@ -86,10 +93,10 @@ class CT2CausalLM(Model):
         elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
             # float16 is not available on CPU
             # and int16 has no stable implementation
             ct2_compute_type = "float32"
         else:
             # default, int8 quantization.
             if "cuda" in self.ct2_device:
                 # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
@@ -108,8 +115,8 @@ class CT2CausalLM(Model):
             converter = ctranslate2.converters.TransformersConverter(
                 model_id,
                 activation_scales=None,
-                load_as_float16=True,
-                revision=None,
+                load_as_float16=ct2_compute_type != "bfloat16",
+                revision=revision,
                 low_cpu_mem_usage=True,
                 trust_remote_code=trust_remote_code,
             )
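Passing revision=revision through (instead of hard-coding None) means pinned model revisions survive the conversion step, and load_as_float16=ct2_compute_type != "bfloat16" avoids down-casting weights that should stay bf16. For reference, a conversion call with this converter looks roughly like the following (model id and output path are illustrative):

import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "gpt2",                     # illustrative model id
    load_as_float16=True,
    revision=None,              # or a pinned revision of the HF repo
    low_cpu_mem_usage=True,
)
# quantization should line up with the compute type chosen at load time
converter.convert("ct2_model_dir", quantization="int8_float16", force=True)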
@@ -125,13 +132,15 @@ class CT2CausalLM(Model):
             )
             if not os.path.exists(out_dir / "model.bin"):
                 raise ValueError(
-                    f"no ctranslate2 for {model_id} found after conversion in {out_dir}"
+                    f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
                 )

         # Start CT2
         self.ct2_model = ctranslate2.Generator(
-            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
-            **ct2_generator_kwargs
+            str(out_dir),
+            device=self.ct2_device,
+            compute_type=ct2_compute_type,
+            **ct2_generator_kwargs,
         )

         class DummyModel(torch.nn.Module):
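On the Generator kwargs: in ctranslate2, inter_threads bounds how many batches can run in parallel and intra_threads is the per-batch worker thread count, so the env vars introduced above map directly onto these. A hedged construction example (path and values illustrative):

import ctranslate2

ct2_model = ctranslate2.Generator(
    "ct2_model_dir",        # illustrative converted-model directory
    device="cpu",
    compute_type="int8",
    inter_threads=1,        # parallel batch workers
    intra_threads=4,        # threads per batch
)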
@@ -216,19 +225,17 @@ class CT2CausalLM(Model):
         )
         # lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
         lengths = np.array(input_lengths, dtype=np.int32)

         if self.ct2_device == "cuda":
-            lengths = torch.from_numpy(lengths).to(
-                self.ct2_device
-            )
+            lengths = torch.from_numpy(lengths).to(self.ct2_device)
         elif self.ct2_device == "cpu":
             ids_input = ids_input.numpy()

         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)

         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)
         # continue with logits as torch tensor
         if self.ct2_device == "cpu":
             # logits is a float32 torch cpu tensor
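forward_batch returns a ctranslate2.StorageView rather than a torch tensor, so the code following this hunk presumably converts it; StorageView exposes the NumPy array interface on CPU and __cuda_array_interface__ on GPU. One common conversion pattern, assuming logits is the StorageView returned above (storage_view_to_torch is a hypothetical helper, not part of this commit):

import numpy as np
import torch

def storage_view_to_torch(logits, device: str) -> torch.Tensor:
    # hypothetical helper illustrating the conversion
    if device == "cpu":
        # zero-copy view through the NumPy array interface
        return torch.from_numpy(np.asarray(logits))
    # torch understands __cuda_array_interface__, so this avoids a host copy
    return torch.as_tensor(logits, device=device)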