mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
cpu speedup kwargs
This commit is contained in:
parent
336ea37637
commit
be6c9acf46
@ -18,6 +18,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
@ -71,10 +72,12 @@ class CT2CausalLM(Model):
|
||||
)
|
||||
|
||||
# Start CT2
|
||||
ct2_generator_kwargs = {"inter_threads": 1}
|
||||
if torch.cuda.is_available():
|
||||
self.ct2_device = "cuda"
|
||||
else:
|
||||
self.ct2_device = "cpu"
|
||||
ct2_generator_kwargs["intra_threads"] = multiprocessing.cpu_count() // 2
|
||||
|
||||
if dtype == torch.float16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "float16"
|
||||
@ -127,7 +130,8 @@ class CT2CausalLM(Model):
|
||||
|
||||
# Start CT2
|
||||
self.ct2_model = ctranslate2.Generator(
|
||||
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type
|
||||
str(out_dir), device=self.ct2_device, compute_type=ct2_compute_type,
|
||||
**ct2_generator_kwargs
|
||||
)
|
||||
|
||||
class DummyModel(torch.nn.Module):
|
||||
@ -210,7 +214,7 @@ class CT2CausalLM(Model):
|
||||
.flatten(1)
|
||||
.to(torch.int32)
|
||||
)
|
||||
# lengths of the padded ids_input, i.e. how often 1234567 is used.
|
||||
# lengths of the padded ids_input, i.e. the number of entries that are not the pad value (pad=1234567).
|
||||
lengths = np.array(input_lengths, dtype=np.int32)
|
||||
|
||||
if self.ct2_device == "cuda":
|
||||
|
Loading…
Reference in New Issue
Block a user