Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Fix another issue and make the code pass on CPU.
This commit is contained in:
parent b2575fd18d
commit 336ea37637
@@ -76,20 +76,24 @@ class CT2CausalLM(Model):
         else:
             self.ct2_device = "cpu"
 
-        if dtype == torch.float16:
+        if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
-        elif dtype == torch.float16:
+        elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
             ct2_compute_type = "bfloat16"
+        elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
+            # float16 is not available on CPU
+            # and int16 has no stable implementation
+            ct2_compute_type = "float32"
         else:
             # default, int8 quantization.
-            # Fastest and lowest
             if "cuda" in self.ct2_device:
+                # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
             else:
+                # int8 for int8 layers, float32 for non-quantized layers
                 ct2_compute_type = "int8"
-            # print("cpu is currently experimental due to"
-            #       " sampling based / non-greedy next_token"
-            #       " of code only working in float16.")
 
         # Start CT2 - conversion
         out_dir = (
             Path(HUGGINGFACE_HUB_CACHE)
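
The hunk above reduces compute-type selection to a pure function of device and dtype. A minimal sketch of that mapping as a standalone helper (the name select_compute_type and the assert lines are illustrative, not part of the repo):

import torch

def select_compute_type(device: str, dtype: torch.dtype) -> str:
    # Hypothetical helper mirroring the branch in the diff above.
    if device == "cuda" and dtype == torch.float16:
        return "float16"
    if device == "cuda" and dtype == torch.bfloat16:
        return "bfloat16"
    if device == "cpu" and dtype in (torch.float16, torch.bfloat16):
        # half precision is unsupported on CPU, so fall back to float32
        return "float32"
    # default: int8 weights; non-quantized layers stay in float16 on GPU
    return "int8_float16" if "cuda" in device else "int8"

assert select_compute_type("cpu", torch.float16) == "float32"
assert select_compute_type("cuda", torch.bfloat16) == "bfloat16"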
@@ -150,7 +154,7 @@ class CT2CausalLM(Model):
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=torch.int32,
-            device=torch.device("cuda"),
+            device=torch.device(self.ct2_device),
         )
 
     @property
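
This one-line fix works because torch.device accepts the same "cpu"/"cuda" strings that CTranslate2 takes for its device argument, so one runtime-selected string can drive both libraries. A quick illustration (assuming torch is installed):

import torch

ct2_device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(ct2_device)  # valid for both backends
ids = torch.zeros(2, 3, dtype=torch.int32, device=device)
print(ids.device)  # cuda:0 or cpu, depending on availability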
@@ -170,8 +174,8 @@ class CT2CausalLM(Model):
         # logits = self.ct2_model.forward_batch(
         #     tokens_in
         # )
-        # logits = torch.as_tensor(logits, device="cuda")
-        # logits = logits.to("cuda").to(torch.float16)
+        # logits = torch.as_tensor(logits, device=all_input_ids.device)
+        # logits = logits.to(torch.float16) if all_input_ids.device.type == "cuda" else logits.to(torch.float32)
         # return logits, None
 
     # def forward_greedy_logits(
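
The updated comments sketch a device-agnostic cast: float16 only makes sense when the logits live on CUDA. A runnable version of that pattern (cast_logits is a hypothetical name, not a function in the repo):

import torch

def cast_logits(logits: torch.Tensor) -> torch.Tensor:
    # float16 on GPU, float32 on CPU, mirroring the commented-out pattern
    if logits.device.type == "cuda":
        return logits.to(torch.float16)
    return logits.to(torch.float32)

print(cast_logits(torch.randn(1, 4)).dtype)  # torch.float32 on a CPU-only box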
@@ -207,22 +211,27 @@ class CT2CausalLM(Model):
             .to(torch.int32)
         )
         # lengths of the padded ids_input, i.e. how often 1234567 is used.
-        lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(
-            ids_input.device
-        )
-        if self.ct2_device == "cpu":
-            ids_input, lengths = ids_input.numpy(), lengths.numpy()
+        lengths = np.array(input_lengths, dtype=np.int32)
+        if self.ct2_device == "cuda":
+            lengths = torch.from_numpy(lengths).to(
+                self.ct2_device
+            )
+        elif self.ct2_device == "cpu":
+            ids_input = ids_input.numpy()
 
         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)
         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)
-        logits = torch.as_tensor(logits, device=self.ct2_device)
-        # continue with logits as torch tensor, move it to dtype
+        # continue with logits as torch tensor
         if self.ct2_device == "cpu":
-            logits = logits.to(self.ct2_device).to(torch.float32)
+            # logits is a float32 torch cpu tensor
+            logits = torch.from_numpy(np.asarray(logits))
         else:
-            logits = logits.to("cuda").to(torch.float16)
+            # logits is a float16 torch cuda tensor
+            logits = torch.as_tensor(logits, device=self.ct2_device)
         return logits, None
 
     @tracer.start_as_current_span("generate_token")
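
The hunk above hands tensors to CTranslate2 as StorageView objects and converts the returned logits back to torch. A minimal sketch of the CPU round-trip (no model is loaded; only the array hand-off is shown, and it assumes ctranslate2 is installed):

import ctranslate2
import numpy as np
import torch

ids_input = torch.tensor([[1, 2, 3]], dtype=torch.int32)
lengths = np.array([3], dtype=np.int32)

# On CPU, StorageView.from_array reads the NumPy array interface, so
# torch tensors are converted to NumPy first, as in the diff.
sv_ids = ctranslate2.StorageView.from_array(ids_input.numpy())
sv_lengths = ctranslate2.StorageView.from_array(lengths)

# A CPU StorageView exposes __array_interface__, so results come back
# through np.asarray; on CUDA, torch.as_tensor(sv, device="cuda") reads
# __cuda_array_interface__ instead, avoiding a host copy.
round_trip = torch.from_numpy(np.asarray(sv_ids))
assert torch.equal(round_trip, ids_input)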
@@ -42,7 +42,7 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if torch.cuda.is_available():
+        if scores.device.type == "cuda":
             if self.cuda_graph is None:
                 self.static_scores = scores
                 self.cuda_graph = torch.cuda.CUDAGraph()
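
The guard change above matters because StaticWarper captures a CUDA graph on first call and replays it afterwards; gating on the score tensor's device (rather than global CUDA availability) keeps CPU tensors on the eager path even when a GPU is present. A minimal sketch of the capture/replay pattern, assuming stable tensor shapes across calls (warp_with_graph and the state dict are illustrative stand-ins for the class, not repo code):

import torch

def warp_with_graph(scores: torch.Tensor, state: dict) -> torch.Tensor:
    if scores.device.type != "cuda":
        # eager CPU fallback, as in the patched guard
        return torch.log_softmax(scores, dim=-1)
    if state.get("graph") is None:
        # first CUDA call: record the work into a graph
        state["static_scores"] = scores
        state["graph"] = torch.cuda.CUDAGraph()
        with torch.cuda.graph(state["graph"]):
            state["static_out"] = torch.log_softmax(state["static_scores"], dim=-1)
    else:
        # later calls: refresh the captured input buffer in place
        state["static_scores"].copy_(scores)
    state["graph"].replay()
    return state["static_out"]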