Fix other issues and make the code pass on CPU.

michaelfeil 2023-07-24 11:03:02 +02:00
parent b2575fd18d
commit 336ea37637
2 changed files with 29 additions and 20 deletions


@@ -76,20 +76,24 @@ class CT2CausalLM(Model):
         else:
             self.ct2_device = "cpu"
-        if dtype == torch.float16:
+        if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
-        elif dtype == torch.float16:
+        elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
             ct2_compute_type = "bfloat16"
+        elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
+            # float16 is not available on CPU,
+            # and int16 has no stable implementation
+            ct2_compute_type = "float32"
         else:
             # default: int8 quantization.
             # Fastest and lowest memory usage.
             if "cuda" in self.ct2_device:
                 # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
             else:
                 # int8 for int8 layers, float32 for non-quantized layers
                 ct2_compute_type = "int8"
         # print("cpu is currently experimental due to"
         #       " sampling-based / non-greedy next_token selection"
         #       " only working in float16.")
         # Start CT2 conversion
         out_dir = (
             Path(HUGGINGFACE_HUB_CACHE)
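
Read on its own, the chain above maps (device, dtype) to a CTranslate2 compute_type. A minimal sketch of that mapping as a pure function, for illustration only (the helper name is not part of the commit):

    import torch

    def select_ct2_compute_type(device: str, dtype: torch.dtype) -> str:
        # Mirrors the if/elif chain in the hunk above.
        if dtype == torch.float16 and device == "cuda":
            return "float16"
        if dtype == torch.bfloat16 and device == "cuda":
            return "bfloat16"
        if device == "cpu" and dtype in (torch.float16, torch.bfloat16):
            # CTranslate2 has no float16 kernels on CPU; fall back to float32.
            return "float32"
        # Default: int8 weight quantization; non-quantized layers stay
        # float16 on GPU and float32 on CPU.
        return "int8_float16" if "cuda" in device else "int8"

    assert select_ct2_compute_type("cpu", torch.float16) == "float32"
    assert select_ct2_compute_type("cuda", torch.bfloat16) == "bfloat16"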
@@ -150,7 +154,7 @@ class CT2CausalLM(Model):
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=torch.int32,
-            device=torch.device("cuda"),
+            device=torch.device(self.ct2_device),
         )

     @property
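
The only change here is building the batch device from the stored string instead of hard-coding "cuda". torch.device accepts the same "cpu"/"cuda" strings that CTranslate2 takes for its device argument, so a single attribute can drive both libraries; a quick illustration:

    import torch

    ct2_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(ct2_device)  # torch.device("cpu") or torch.device("cuda")
    ids = torch.zeros(2, 3, dtype=torch.int32, device=device)  # follows the model's device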
@@ -170,8 +174,8 @@ class CT2CausalLM(Model):
         # logits = self.ct2_model.forward_batch(
         #     tokens_in
         # )
-        # logits = torch.as_tensor(logits, device="cuda")
-        # logits = logits.to("cuda").to(torch.float16)
+        # logits = torch.as_tensor(logits, device=all_input_ids.device)
+        # logits = logits.to(torch.float16) if all_input_ids.device.type == "cuda" else logits.to(torch.float32)
         # return logits, None

     # def forward_greedy_logits(
@ -207,22 +211,27 @@ class CT2CausalLM(Model):
.to(torch.int32)
)
# lengths of the padded ids_input, i.e. how often 1234567 is used.
lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(
ids_input.device
)
if self.ct2_device == "cpu":
ids_input, lengths = ids_input.numpy(), lengths.numpy()
lengths = np.array(input_lengths, dtype=np.int32)
if self.ct2_device == "cuda":
lengths = torch.from_numpy(lengths).to(
self.ct2_device
)
elif self.ct2_device == "cpu":
ids_input = ids_input.numpy()
ids_input = ctranslate2.StorageView.from_array(ids_input)
lengths = ctranslate2.StorageView.from_array(lengths)
# now, forward through the network
logits = self.ct2_model.forward_batch(ids_input, lengths)
logits = torch.as_tensor(logits, device=self.ct2_device)
# continue with logits as torch tensor, move it to dtype
# continue with logits as torch tensor
if self.ct2_device == "cpu":
logits = logits.to(self.ct2_device).to(torch.float32)
# logits is a float32 torch cpu tensor
logits = torch.from_numpy(np.asarray(logits))
else:
logits = logits.to("cuda").to(torch.float16)
# logits is a float16 torch cuda tensor
logits = torch.as_tensor(logits, device=self.ct2_device)
return logits, None
@tracer.start_as_current_span("generate_token")
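
Putting the new input/output path together: on CPU, StorageView.from_array wraps numpy arrays; on CUDA it can wrap torch tensors through the CUDA array interface, and the logits come back float32 on CPU and float16 on GPU. A self-contained sketch of the same round trip, with a placeholder model directory:

    import ctranslate2
    import numpy as np
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # "ct2_model_dir" is a placeholder for a converted CTranslate2 model.
    generator = ctranslate2.Generator("ct2_model_dir", device=device)

    ids_input = torch.tensor([[5, 17, 29]], dtype=torch.int32)
    lengths = np.array([3], dtype=np.int32)

    if device == "cuda":
        # On GPU, StorageView wraps CUDA torch tensors directly.
        ids_sv = ctranslate2.StorageView.from_array(ids_input.to(device))
        lengths_sv = ctranslate2.StorageView.from_array(torch.from_numpy(lengths).to(device))
    else:
        # On CPU, StorageView wraps numpy arrays.
        ids_sv = ctranslate2.StorageView.from_array(ids_input.numpy())
        lengths_sv = ctranslate2.StorageView.from_array(lengths)

    logits_sv = generator.forward_batch(ids_sv, lengths_sv)

    if device == "cpu":
        # CPU logits are float32; go through numpy, as in the hunk above.
        logits = torch.from_numpy(np.asarray(logits_sv))
    else:
        # GPU logits stay on device (float16 for float16/int8_float16 models).
        logits = torch.as_tensor(logits_sv, device=device)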


@@ -42,7 +42,7 @@ class StaticWarper:
         self.static_next_logprob = None

     def __call__(self, scores):
-        if torch.cuda.is_available():
+        if scores.device.type == "cuda":
             if self.cuda_graph is None:
                 self.static_scores = scores
                 self.cuda_graph = torch.cuda.CUDAGraph()
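
This guard change matters because torch.cuda.is_available() only reports that a GPU exists; it says nothing about where scores lives. With the CT2 model on CPU on a GPU machine, the old check would enter the CUDA-graph path with CPU tensors. A tiny illustration:

    import torch

    scores = torch.randn(1, 8)  # CPU tensor, e.g. logits from the CPU CT2 path

    # On a machine with a GPU, torch.cuda.is_available() is True even though
    # `scores` is on CPU, so the old guard would wrongly take the CUDA-graph path.
    old_guard = torch.cuda.is_available()
    new_guard = scores.device.type == "cuda"  # False: falls through to eager warping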