mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

commit e38cda5b9b (parent be6c9acf46): apply suggested changes
@@ -72,12 +72,19 @@ class CT2CausalLM(Model):
         )

         # Start CT2
-        ct2_generator_kwargs = {"inter_threads": 1}
+        ct2_generator_kwargs = {
+            "inter_threads": os.environ.get("TGI_CT2_INTER_THREADS", 1)
+        }
         if torch.cuda.is_available():
             self.ct2_device = "cuda"
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", 1
+            )
         else:
             self.ct2_device = "cpu"
-            ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
+            ct2_generator_kwargs["intra_threads"] = os.environ.get(
+                "TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
+            )

         if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
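One caveat with the lines above: `os.environ.get` returns a string whenever the variable is set, while the `inter_threads`/`intra_threads` options of `ctranslate2.Generator` expect integers. A minimal sketch of a safer lookup; the `_env_int` helper is our own invention for illustration, not part of this commit:

```python
import multiprocessing
import os


def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment, falling back to a default."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


# Mirrors the kwargs built in the diff, with an explicit int conversion.
ct2_generator_kwargs = {
    "inter_threads": _env_int("TGI_CT2_INTER_THREADS", 1),
    "intra_threads": _env_int(
        "TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
    ),
}
```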
@@ -86,10 +93,10 @@ class CT2CausalLM(Model):
         elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
             # float16 is not available on CPU
             # and int16 has no stable implementation
             ct2_compute_type = "float32"
         else:
             # default, int8 quantization.

             if "cuda" in self.ct2_device:
                 # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
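For reference, the compute-type fallback in this context can be read as a small pure function. The sketch below only covers the branches visible in the hunk (a cuda/bfloat16 case is implied by the `ct2_compute_type != "bfloat16"` check in the next hunk but elided here), and the plain `"int8"` CPU default is our assumption:

```python
import torch


def select_ct2_compute_type(dtype: torch.dtype, device: str) -> str:
    """Pick a CTranslate2 compute type for the given torch dtype and device."""
    if dtype == torch.float16 and device == "cuda":
        return "float16"
    if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
        # float16 is not available on CPU,
        # and int16 has no stable implementation
        return "float32"
    # default: int8 quantization
    if "cuda" in device:
        # int8 for int8 layers, float16 for non-quantized layers
        return "int8_float16"
    return "int8"
```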
@@ -108,8 +115,8 @@ class CT2CausalLM(Model):
             converter = ctranslate2.converters.TransformersConverter(
                 model_id,
                 activation_scales=None,
-                load_as_float16=True,
-                revision=None,
+                load_as_float16=ct2_compute_type != "bfloat16",
+                revision=revision,
                 low_cpu_mem_usage=True,
                 trust_remote_code=trust_remote_code,
             )
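The arguments above are passed straight to `ctranslate2.converters.TransformersConverter`. A hedged usage sketch of the surrounding conversion step, assuming the standard `Converter.convert` entry point; `model_id` and `out_dir` are illustrative values, not taken from the commit:

```python
from pathlib import Path

import ctranslate2

model_id = "gpt2"  # assumption: any Transformers causal LM works here
out_dir = Path("ct2_model")

converter = ctranslate2.converters.TransformersConverter(
    model_id,
    activation_scales=None,
    load_as_float16=True,  # the commit gates this on ct2_compute_type != "bfloat16"
    revision=None,
    low_cpu_mem_usage=True,
    trust_remote_code=False,
)
# Writes model.bin plus config into out_dir.
converter.convert(str(out_dir), force=True)
```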
@@ -125,13 +132,15 @@ class CT2CausalLM(Model):
             )
         if not os.path.exists(out_dir / "model.bin"):
             raise ValueError(
-                f"no ctranslate2 for {model_id} found after conversion in {out_dir}"
+                f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
             )

         # Start CT2
         self.ct2_model = ctranslate2.Generator(
-            str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
-            **ct2_generator_kwargs
+            str(out_dir),
+            device=self.ct2_device,
+            compute_type=ct2_compute_type,
+            **ct2_generator_kwargs,
         )

         class DummyModel(torch.nn.Module):
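Putting the pieces together, loading the converted model is a single `ctranslate2.Generator` call. A minimal sketch under the assumptions above; the device, compute type, and thread counts shown are placeholders rather than the values TGI resolves at runtime:

```python
import os
from pathlib import Path

import ctranslate2

out_dir = Path("ct2_model")  # assumption: output of the conversion step above
if not os.path.exists(out_dir / "model.bin"):
    raise ValueError(f"no ctranslate2 model found after conversion in {out_dir}")

ct2_model = ctranslate2.Generator(
    str(out_dir),
    device="cpu",          # or "cuda"
    compute_type="int8",   # see the compute-type selection above
    inter_threads=1,
    intra_threads=4,
)
```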
@@ -216,19 +225,17 @@ class CT2CausalLM(Model):
         )
         # lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
         lengths = np.array(input_lengths, dtype=np.int32)

         if self.ct2_device == "cuda":
-            lengths = torch.from_numpy(lengths).to(
-                self.ct2_device
-            )
+            lengths = torch.from_numpy(lengths).to(self.ct2_device)
         elif self.ct2_device == "cpu":
             ids_input = ids_input.numpy()

         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)
         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)

         # continue with logits as torch tensor
         if self.ct2_device == "cpu":
             # logits is a float32 torch cpu tensor
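The forward pass feeds `StorageView`s into `forward_batch`, as the hunk shows. A sketch of the round trip from token ids to a torch logits tensor, assuming a CPU device; the toy batch, pad layout, and model path are illustrative:

```python
import ctranslate2
import numpy as np
import torch

# Toy batch: two sequences padded to length 4 (int32 ids, as CTranslate2 expects).
ids_input = np.array([[1, 2, 3, 4], [5, 6, 0, 0]], dtype=np.int32)
lengths = np.array([4, 2], dtype=np.int32)  # non-pad tokens per row

ids_view = ctranslate2.StorageView.from_array(ids_input)
lengths_view = ctranslate2.StorageView.from_array(lengths)

generator = ctranslate2.Generator("ct2_model", device="cpu")
logits_view = generator.forward_batch(ids_view, lengths_view)

# On CPU the StorageView exposes the array interface, so numpy/torch can wrap it;
# per the diff's own comment, this yields a float32 CPU tensor.
logits = torch.as_tensor(np.asarray(logits_view))
```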