From 336ea376378967c9a402f5235e1ce01c351acfbd Mon Sep 17 00:00:00 2001
From: michaelfeil
Date: Mon, 24 Jul 2023 11:03:02 +0200
Subject: [PATCH] fix compute-type selection and make code pass on cpu.

Select the CTranslate2 compute type from both the requested dtype and
the device, run the forward pass on CPU as well as CUDA, and key the
CUDA-graph warper on the device of the scores tensor. Minimal sketches
of these three paths follow the diff.

---
 .../models/ct2_causal_lm.py                 | 47 +++++++++++--------
 .../utils/logits_process.py                 |  2 +-
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py
index 6334efe7..8314830a 100644
--- a/server/text_generation_server/models/ct2_causal_lm.py
+++ b/server/text_generation_server/models/ct2_causal_lm.py
@@ -76,20 +76,24 @@ class CT2CausalLM(Model):
         else:
             self.ct2_device = "cpu"
 
-        if dtype == torch.float16:
+        if dtype == torch.float16 and self.ct2_device == "cuda":
             ct2_compute_type = "float16"
-        elif dtype == torch.float16:
+        elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
             ct2_compute_type = "bfloat16"
+        elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
+            # float16 is not available on CPU
+            # and int16 has no stable implementation
+            ct2_compute_type = "float32"
         else:
             # default, int8 quantization.
-            # Fastest and lowest
+            if "cuda" in self.ct2_device:
+                # int8 for int8 layers, float16 for non-quantized layers
                 ct2_compute_type = "int8_float16"
             else:
+                # int8 for int8 layers, float32 for non-quantized layers
                 ct2_compute_type = "int8"
-        # print("cpu is currently experimental due to"
-        #       " sampling based / non-greedy next_token"
-        #       " of code only working in float16.")
+
         # Start CT2 - conversion
         out_dir = (
             Path(HUGGINGFACE_HUB_CACHE)
@@ -150,7 +154,7 @@ class CT2CausalLM(Model):
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=torch.int32,
-            device=torch.device("cuda"),
+            device=torch.device(self.ct2_device),
         )
 
     @property
@@ -170,8 +174,8 @@
         # logits = self.ct2_model.forward_batch(
         #     tokens_in
         # )
-        # logits = torch.as_tensor(logits, device="cuda")
-        # logits = logits.to("cuda").to(torch.float16)
+        # logits = torch.as_tensor(logits, device=all_input_ids.device)
+        # logits = logits.to(torch.float16) if all_input_ids.device.type == "cuda" else logits.to(torch.float32)
         # return logits, None
 
     # def forward_greedy_logits(
@@ -207,22 +211,27 @@
             .to(torch.int32)
         )
         # lengths of the padded ids_input, i.e. how often 1234567 is used.
-        lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(
-            ids_input.device
-        )
-
-        if self.ct2_device == "cpu":
-            ids_input, lengths = ids_input.numpy(), lengths.numpy()
+        lengths = np.array(input_lengths, dtype=np.int32)
+
+        if self.ct2_device == "cuda":
+            lengths = torch.from_numpy(lengths).to(
+                self.ct2_device
+            )
+        elif self.ct2_device == "cpu":
+            ids_input = ids_input.numpy()
+
         ids_input = ctranslate2.StorageView.from_array(ids_input)
         lengths = ctranslate2.StorageView.from_array(lengths)
 
         # now, forward through the network
         logits = self.ct2_model.forward_batch(ids_input, lengths)
-        logits = torch.as_tensor(logits, device=self.ct2_device)
-        # continue with logits as torch tensor, move it to dtype
+
+        # continue with logits as torch tensor
         if self.ct2_device == "cpu":
-            logits = logits.to(self.ct2_device).to(torch.float32)
+            # logits is a float32 torch cpu tensor
+            logits = torch.from_numpy(np.asarray(logits))
         else:
-            logits = logits.to("cuda").to(torch.float16)
+            # logits is a float16 torch cuda tensor
+            logits = torch.as_tensor(logits, device=self.ct2_device)
         return logits, None
 
     @tracer.start_as_current_span("generate_token")
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index f424eae4..e29321a1 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -42,7 +42,7 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if torch.cuda.is_available():
+        if scores.device.type == "cuda":
             if self.cuda_graph is None:
                 self.static_scores = scores
                 self.cuda_graph = torch.cuda.CUDAGraph()
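
Note on the compute-type dispatch in the first ct2_causal_lm.py hunk:
the selection logic is self-contained and can be checked in isolation.
A minimal sketch, assuming a free-standing helper (the name
select_compute_type is illustrative, not part of the repo):

import torch

def select_compute_type(dtype: torch.dtype, device: str) -> str:
    # Mirror the dispatch added above: half precision only on CUDA,
    # float32 fallback on CPU, int8 quantization otherwise.
    if dtype == torch.float16 and device == "cuda":
        return "float16"
    if dtype == torch.bfloat16 and device == "cuda":
        return "bfloat16"
    if device == "cpu" and dtype in (torch.float16, torch.bfloat16):
        # float16 is not available on CPU, and int16 has no stable
        # implementation, so fall back to full precision.
        return "float32"
    # default: int8 for quantized layers; non-quantized layers run in
    # float16 on CUDA and float32 on CPU.
    return "int8_float16" if "cuda" in device else "int8"

assert select_compute_type(torch.float16, "cpu") == "float32"
assert select_compute_type(torch.bfloat16, "cuda") == "bfloat16"
assert select_compute_type(torch.float32, "cpu") == "int8"

The CPU fallback matters because CTranslate2 has no float16 kernels on
CPU, so requesting half precision there has to degrade to float32
rather than fail.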
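
The new forward path in the last ct2_causal_lm.py hunk round-trips
between torch/NumPy and ctranslate2.StorageView. A minimal sketch of
the same round-trip as a free-standing function (forward_logits is an
illustrative name; it assumes a loaded ctranslate2.Generator and
already-padded int32 ids, as in the patch):

import ctranslate2
import numpy as np
import torch

def forward_logits(
    model: ctranslate2.Generator,
    ids_input: torch.Tensor,   # padded int32 ids, [batch, seq_len]
    input_lengths: list,       # unpadded length of each row
    device: str,               # "cpu" or "cuda"
) -> torch.Tensor:
    lengths = np.array(input_lengths, dtype=np.int32)
    if device == "cuda":
        # CUDA torch tensors are wrapped via __cuda_array_interface__.
        lengths = torch.from_numpy(lengths).to(device)
    else:
        # On CPU, StorageView is built from NumPy arrays instead.
        ids_input = ids_input.numpy()
    ids_sv = ctranslate2.StorageView.from_array(ids_input)
    lengths_sv = ctranslate2.StorageView.from_array(lengths)
    logits = model.forward_batch(ids_sv, lengths_sv)
    if device == "cpu":
        # A CPU StorageView exposes __array_interface__, so NumPy can
        # wrap it; the result here is float32.
        return torch.from_numpy(np.asarray(logits))
    # A CUDA StorageView exposes __cuda_array_interface__, so
    # torch.as_tensor yields a float16 CUDA tensor.
    return torch.as_tensor(logits, device=device)

The two branches build their inputs differently because StorageView
wraps CUDA torch tensors and CPU NumPy arrays through their respective
array interfaces, which is exactly the asymmetry the patch handles.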
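
The logits_process.py change replaces a global
torch.cuda.is_available() test with a check on the device of the
incoming scores tensor, so a CPU model served from a CUDA-capable host
no longer tries to capture a CUDA graph over CPU tensors. A minimal
sketch of the pattern, as a stripped-down stand-in for
StaticWarper.__call__ (the module-level state and the log-softmax body
are illustrative, not the repo's full implementation):

import torch

_graph = None
_static_scores = None
_static_warped = None

def warp(scores: torch.Tensor) -> torch.Tensor:
    global _graph, _static_scores, _static_warped
    # Key the fast path on the tensor's device, not on global CUDA
    # availability: CUDA-graph capture requires CUDA tensors, so a CPU
    # tensor on a GPU machine must fall through to the eager branch.
    if scores.device.type == "cuda":
        if _graph is None:
            # Capture once into static buffers, then replay.
            _static_scores = scores
            _graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(_graph):
                _static_warped = torch.log_softmax(_static_scores, dim=-1)
        else:
            _static_scores.copy_(scores)
        _graph.replay()
        return _static_warped
    # Eager path for CPU (and any non-CUDA device).
    return torch.log_softmax(scores, dim=-1)

print(warp(torch.randn(2, 8)))  # takes the eager CPU branch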