mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fix other issue and make code pass on cpu.
This commit is contained in:
parent
b2575fd18d
commit
336ea37637
@ -76,20 +76,24 @@ class CT2CausalLM(Model):
|
||||
else:
|
||||
self.ct2_device = "cpu"
|
||||
|
||||
if dtype == torch.float16:
|
||||
if dtype == torch.float16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "float16"
|
||||
elif dtype == torch.float16:
|
||||
elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "bfloat16"
|
||||
elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
|
||||
# float16 is not available on CPU
|
||||
# and int16 has no stable implementation
|
||||
ct2_compute_type = "float32"
|
||||
else:
|
||||
# default, int8 quantization.
|
||||
# Fastest and lowest
|
||||
|
||||
if "cuda" in self.ct2_device:
|
||||
# int8 for int8 layers, float16 for non-quantized layers
|
||||
ct2_compute_type = "int8_float16"
|
||||
else:
|
||||
# int8 for int8 layers, float32 for non-quantized layers
|
||||
ct2_compute_type = "int8"
|
||||
# print("cpu is currently experimental due to"
|
||||
# " sampling based / non-greedy next_token"
|
||||
# " of code only working in float16.")
|
||||
|
||||
# Start CT2 - conversion
|
||||
out_dir = (
|
||||
Path(HUGGINGFACE_HUB_CACHE)
|
||||
@ -150,7 +154,7 @@ class CT2CausalLM(Model):
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=torch.int32,
|
||||
device=torch.device("cuda"),
|
||||
device=torch.device(self.ct2_device),
|
||||
)
|
||||
|
||||
@property
|
||||
@ -170,8 +174,8 @@ class CT2CausalLM(Model):
|
||||
# logits = self.ct2_model.forward_batch(
|
||||
# tokens_in
|
||||
# )
|
||||
# logits = torch.as_tensor(logits, device="cuda")
|
||||
# logits = logits.to("cuda").to(torch.float16)
|
||||
# logits = torch.as_tensor(logits, device=all_input_ids.device)
|
||||
# logits = logits.to(torch.float16) if all_input_ids.device.type == "cuda" else logits.to(torch.float32)
|
||||
# return logits, None
|
||||
|
||||
# def forward_greedy_logits(
|
||||
@ -207,22 +211,27 @@ class CT2CausalLM(Model):
|
||||
.to(torch.int32)
|
||||
)
|
||||
# lengths of the padded ids_input, i.e. how often 1234567 is used.
|
||||
lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(
|
||||
ids_input.device
|
||||
)
|
||||
|
||||
if self.ct2_device == "cpu":
|
||||
ids_input, lengths = ids_input.numpy(), lengths.numpy()
|
||||
lengths = np.array(input_lengths, dtype=np.int32)
|
||||
|
||||
if self.ct2_device == "cuda":
|
||||
lengths = torch.from_numpy(lengths).to(
|
||||
self.ct2_device
|
||||
)
|
||||
elif self.ct2_device == "cpu":
|
||||
ids_input = ids_input.numpy()
|
||||
|
||||
ids_input = ctranslate2.StorageView.from_array(ids_input)
|
||||
lengths = ctranslate2.StorageView.from_array(lengths)
|
||||
# now, forward through the network
|
||||
logits = self.ct2_model.forward_batch(ids_input, lengths)
|
||||
logits = torch.as_tensor(logits, device=self.ct2_device)
|
||||
# continue with logits as torch tensor, move it to dtype
|
||||
|
||||
# continue with logits as torch tensor
|
||||
if self.ct2_device == "cpu":
|
||||
logits = logits.to(self.ct2_device).to(torch.float32)
|
||||
# logits is a float32 torch cpu tensor
|
||||
logits = torch.from_numpy(np.asarray(logits))
|
||||
else:
|
||||
logits = logits.to("cuda").to(torch.float16)
|
||||
# logits is a float16 torch cuda tensor
|
||||
logits = torch.as_tensor(logits, device=self.ct2_device)
|
||||
return logits, None
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
|
@ -42,7 +42,7 @@ class StaticWarper:
|
||||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if torch.cuda.is_available():
|
||||
if scores.device.type == "cuda":
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
Loading…
Reference in New Issue
Block a user