mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
cpu speedup kwargs
This commit is contained in:
parent
336ea37637
commit
be6c9acf46
@ -18,6 +18,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
import multiprocessing
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
@ -71,10 +72,12 @@ class CT2CausalLM(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Start CT2
|
# Start CT2
|
||||||
|
ct2_generator_kwargs = {"inter_threads": 1}
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.ct2_device = "cuda"
|
self.ct2_device = "cuda"
|
||||||
else:
|
else:
|
||||||
self.ct2_device = "cpu"
|
self.ct2_device = "cpu"
|
||||||
|
ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
|
||||||
|
|
||||||
if dtype == torch.float16 and self.ct2_device == "cuda":
|
if dtype == torch.float16 and self.ct2_device == "cuda":
|
||||||
ct2_compute_type = "float16"
|
ct2_compute_type = "float16"
|
||||||
@ -127,7 +130,8 @@ class CT2CausalLM(Model):
|
|||||||
|
|
||||||
# Start CT2
|
# Start CT2
|
||||||
self.ct2_model = ctranslate2.Generator(
|
self.ct2_model = ctranslate2.Generator(
|
||||||
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type
|
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
|
||||||
|
**ct2_generator_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
class DummyModel(torch.nn.Module):
|
class DummyModel(torch.nn.Module):
|
||||||
@ -210,7 +214,7 @@ class CT2CausalLM(Model):
|
|||||||
.flatten(1)
|
.flatten(1)
|
||||||
.to(torch.int32)
|
.to(torch.int32)
|
||||||
)
|
)
|
||||||
# lengths of the padded ids_input, i.e. how often 1234567 is used.
|
# lengths of the padded ids_input, i.e. how many entries are not the pad value 1234567.
|
||||||
lengths = np.array(input_lengths, dtype=np.int32)
|
lengths = np.array(input_lengths, dtype=np.int32)
|
||||||
|
|
||||||
if self.ct2_device == "cuda":
|
if self.ct2_device == "cuda":
|
||||||
|
Loading…
Reference in New Issue
Block a user