Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Commit e38cda5b9b ("apply suggested changes"), parent be6c9acf46.
```diff
@@ -72,12 +72,19 @@ class CT2CausalLM(Model):
         )

         # Start CT2
-        ct2_generator_kwargs = {"inter_threads": 1}
+        ct2_generator_kwargs = {
+            "inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
+        }
         if torch.cuda.is_available():
             self.ct2_device = "cuda"
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", 1
+            )
         else:
             self.ct2_device = "cpu"
-            ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
+            )

         if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
```
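This hunk makes the CTranslate2 thread counts overridable via the `TGI_CT2_INTER_THREADS` and `TGI_CT2_INTRA_THREADS` environment variables. Note that `os.environ.get` returns strings while `ctranslate2.Generator` expects integer thread counts, so a caller would likely want an explicit cast. A minimal sketch of the intended configuration with that cast added (the helper name is ours, not from the commit):

```python
import multiprocessing
import os

def ct2_thread_kwargs(cuda_available: bool) -> dict:
    # Environment values arrive as strings, so cast to int explicitly;
    # the defaults mirror the ones introduced in this commit.
    kwargs = {"inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1))}
    if cuda_available:
        kwargs["intra_threads"] = int(os.environ.get("TGI_CT2_INTRA_THREADS", 1))
    else:
        # On CPU, default to half the cores.
        kwargs["intra_threads"] = int(
            os.environ.get("TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2)
        )
    return kwargs
```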
```diff
@@ -86,10 +93,10 @@ class CT2CausalLM(Model):
         elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
             # float16 is not available on CPU
             # and int16 has no stable implementation
-            ct2_compute_type = "float32"
+            ct2_compute_type = "float32"
         else:
-            # default, int8 quantization.
-
+            # default, int8 quantization.
+
             if "cuda" in self.ct2_device:
                 # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
```
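For reference, the full `compute_type` selection reads roughly as below once the edits apply. This is a sketch reconstructed from the visible hunks: the `bfloat16` branch is an assumption inferred from the `load_as_float16=ct2_compute_type != "bfloat16"` guard in the next hunk, and the CPU `int8` default is likewise assumed rather than shown in this diff:

```python
import torch

def pick_ct2_compute_type(dtype: torch.dtype, device: str) -> str:
    if device == "cuda" and dtype == torch.float16:
        return "float16"
    if device == "cuda" and dtype == torch.bfloat16:
        return "bfloat16"  # assumed branch, implied by the converter guard below
    if device == "cpu" and dtype in (torch.float16, torch.bfloat16):
        # float16 is not available on CPU and int16 has no stable implementation
        return "float32"
    # default: int8 quantization, float16 for non-quantized layers on CUDA
    return "int8_float16" if device == "cuda" else "int8"
```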
```diff
@@ -108,8 +115,8 @@ class CT2CausalLM(Model):
             converter = ctranslate2.converters.TransformersConverter(
                 model_id,
                 activation_scales=None,
-                load_as_float16=True,
-                revision=None,
+                load_as_float16=ct2_compute_type != "bfloat16",
+                revision=revision,
                 low_cpu_mem_usage=True,
                 trust_remote_code=trust_remote_code,
             )
```
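The converter now honors the requested `revision` and only pre-casts weights to float16 when the target compute type is not `bfloat16`. A self-contained usage sketch; `model_id`, `out_dir`, and the compute type are illustrative placeholders, not values from the commit:

```python
import ctranslate2

# Illustrative placeholders
model_id = "bigscience/bloom-560m"
revision = None
out_dir = "ct2_model"
ct2_compute_type = "int8_float16"

converter = ctranslate2.converters.TransformersConverter(
    model_id,
    activation_scales=None,
    # bfloat16 checkpoints keep their dtype rather than being cast down
    load_as_float16=ct2_compute_type != "bfloat16",
    revision=revision,
    low_cpu_mem_usage=True,
    trust_remote_code=False,
)
converter.convert(out_dir, force=False)
```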
```diff
@@ -125,13 +132,15 @@ class CT2CausalLM(Model):
             )
             if not os.path.exists(out_dir / "model.bin"):
                 raise ValueError(
-                    f"no ctranslate2 for {model_id} found after conversion in {out_dir}"
+                    f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
                 )

         # Start CT2
         self.ct2_model = ctranslate2.Generator(
-            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
-            **ct2_generator_kwargs
+            str(out_dir),
+            device=self.ct2_device,
+            compute_type=ct2_compute_type,
+            **ct2_generator_kwargs,
         )

 class DummyModel(torch.nn.Module):
```
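With the kwargs split onto separate lines, loading the converted model looks roughly like this; the directory name, device, and thread counts are illustrative:

```python
import os
import ctranslate2

out_dir = "ct2_model"  # illustrative path
if not os.path.exists(os.path.join(out_dir, "model.bin")):
    raise ValueError(f"no ctranslate2 model found after conversion in {out_dir}")

generator = ctranslate2.Generator(
    out_dir,
    device="cpu",               # or "cuda"
    compute_type="int8",        # see the compute-type selection above
    inter_threads=1,
    intra_threads=4,
)
```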
```diff
@@ -216,19 +225,17 @@ class CT2CausalLM(Model):
         )
         # lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
         lengths = np.array(input_lengths, dtype=np.int32)

         if self.ct2_device == "cuda":
-            lengths = torch.from_numpy(lengths).to(
-                self.ct2_device
-            )
+            lengths = torch.from_numpy(lengths).to(self.ct2_device)
         elif self.ct2_device == "cpu":
             ids_input = ids_input.numpy()

         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)
         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)

         # continue with logits as torch tensor
         if self.ct2_device == "cpu":
             # logits is a float32 torch cpu tensor
```
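The forward path wraps token ids and per-row lengths in `StorageView`s before calling `forward_batch`. A CPU-only sketch with made-up token ids; on CPU the returned StorageView implements the array interface, so it can be viewed as a numpy array and wrapped back into a torch tensor:

```python
import ctranslate2
import numpy as np
import torch

generator = ctranslate2.Generator("ct2_model", device="cpu")  # from the sketch above

# Made-up, already padded batch; 0 stands in for the pad id here.
ids_input = torch.tensor([[5, 17, 23, 0], [5, 9, 0, 0]], dtype=torch.int32)
lengths = np.array([3, 2], dtype=np.int32)  # non-padding tokens per row

# CTranslate2 consumes StorageViews: built from numpy arrays on CPU,
# or directly from CUDA torch tensors via __cuda_array_interface__.
ids_sv = ctranslate2.StorageView.from_array(ids_input.numpy())
lengths_sv = ctranslate2.StorageView.from_array(lengths)

logits_sv = generator.forward_batch(ids_sv, lengths_sv)

# Continue with the logits as a float32 torch CPU tensor.
logits = torch.as_tensor(np.asarray(logits_sv))
```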