fix other issue and make code pass on cpu.

This commit is contained in:
michaelfeil 2023-07-24 11:03:02 +02:00
parent b2575fd18d
commit 336ea37637
2 changed files with 29 additions and 20 deletions

View File

@ -76,20 +76,24 @@ class CT2CausalLM(Model):
else: else:
self.ct2_device = "cpu" self.ct2_device = "cpu"
if dtype == torch.float16: if dtype == torch.float16 and self.ct2_device == "cuda":
ct2_compute_type = "float16" ct2_compute_type = "float16"
elif dtype == torch.float16: elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
ct2_compute_type = "bfloat16" ct2_compute_type = "bfloat16"
elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
# float16 is not available on CPU
# and int16 has no stable implementation
ct2_compute_type = "float32"
else: else:
# default, int8 quantization. # default, int8 quantization.
# Fastest and lowest
if "cuda" in self.ct2_device: if "cuda" in self.ct2_device:
# int8 for int8 layers, float16 for non-quantized layers
ct2_compute_type = "int8_float16" ct2_compute_type = "int8_float16"
else: else:
# int8 for int8 layers, float32 for non-quantized layers
ct2_compute_type = "int8" ct2_compute_type = "int8"
# print("cpu is currently experimental due to"
# " sampling based / non-greedy next_token"
# " of code only working in float16.")
# Start CT2 - conversion # Start CT2 - conversion
out_dir = ( out_dir = (
Path(HUGGINGFACE_HUB_CACHE) Path(HUGGINGFACE_HUB_CACHE)
@ -150,7 +154,7 @@ class CT2CausalLM(Model):
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=torch.int32, dtype=torch.int32,
device=torch.device("cuda"), device=torch.device(self.ct2_device),
) )
@property @property
@ -170,8 +174,8 @@ class CT2CausalLM(Model):
# logits = self.ct2_model.forward_batch( # logits = self.ct2_model.forward_batch(
# tokens_in # tokens_in
# ) # )
# logits = torch.as_tensor(logits, device="cuda") # logits = torch.as_tensor(logits, device=all_input_ids.device)
# logits = logits.to("cuda").to(torch.float16) # logits = logits.to(torch.float16) if all_input_ids.device.type == "cuda" else logits.to(torch.float32)
# return logits, None # return logits, None
# def forward_greedy_logits( # def forward_greedy_logits(
@ -207,22 +211,27 @@ class CT2CausalLM(Model):
.to(torch.int32) .to(torch.int32)
) )
# lengths of the padded ids_input, i.e. how often 1234567 is used. # lengths of the padded ids_input, i.e. how often 1234567 is used.
lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to( lengths = np.array(input_lengths, dtype=np.int32)
ids_input.device
) if self.ct2_device == "cuda":
lengths = torch.from_numpy(lengths).to(
if self.ct2_device == "cpu": self.ct2_device
ids_input, lengths = ids_input.numpy(), lengths.numpy() )
elif self.ct2_device == "cpu":
ids_input = ids_input.numpy()
ids_input = ctranslate2.StorageView.from_array(ids_input) ids_input = ctranslate2.StorageView.from_array(ids_input)
lengths = ctranslate2.StorageView.from_array(lengths) lengths = ctranslate2.StorageView.from_array(lengths)
# now, forward through the network # now, forward through the network
logits = self.ct2_model.forward_batch(ids_input, lengths) logits = self.ct2_model.forward_batch(ids_input, lengths)
logits = torch.as_tensor(logits, device=self.ct2_device)
# continue with logits as torch tensor, move it to dtype # continue with logits as torch tensor
if self.ct2_device == "cpu": if self.ct2_device == "cpu":
logits = logits.to(self.ct2_device).to(torch.float32) # logits is a float32 torch cpu tensor
logits = torch.from_numpy(np.asarray(logits))
else: else:
logits = logits.to("cuda").to(torch.float16) # logits is a float16 torch cuda tensor
logits = torch.as_tensor(logits, device=self.ct2_device)
return logits, None return logits, None
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")

View File

@ -42,7 +42,7 @@ class StaticWarper:
self.static_next_logprob = None self.static_next_logprob = None
def __call__(self, scores): def __call__(self, scores):
if torch.cuda.is_available(): if scores.device.type == "cuda":
if self.cuda_graph is None: if self.cuda_graph is None:
self.static_scores = scores self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph() self.cuda_graph = torch.cuda.CUDAGraph()