mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
adapt trust remote code
This commit is contained in:
parent
9b382f9f4a
commit
b2575fd18d
@ -81,14 +81,15 @@ class CT2CausalLM(Model):
|
||||
elif dtype == torch.float16:
|
||||
ct2_compute_type = "bfloat16"
|
||||
else:
|
||||
# default, int8 quantization.
|
||||
# default, int8 quantization.
|
||||
# Fastest and lowest
|
||||
if "cuda" in self.ct2_device:
|
||||
ct2_compute_type = "int8_float16"
|
||||
else:
|
||||
ct2_compute_type = "int8"
|
||||
# raise ValueError("cpu is currently experimental due to"
|
||||
# " sampling based / non-greedy next_token"
|
||||
# " of code only working in float16.")
|
||||
# print("cpu is currently experimental due to"
|
||||
# " sampling based / non-greedy next_token"
|
||||
# " of code only working in float16.")
|
||||
# Start CT2 - conversion
|
||||
out_dir = (
|
||||
Path(HUGGINGFACE_HUB_CACHE)
|
||||
@ -103,7 +104,7 @@ class CT2CausalLM(Model):
|
||||
load_as_float16=True,
|
||||
revision=None,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
converter.convert(
|
||||
output_dir=out_dir,
|
||||
@ -200,12 +201,12 @@ class CT2CausalLM(Model):
|
||||
# CT2 forward requires a list of list of input tokens ids and lengths
|
||||
ids_input = (
|
||||
torch.nested.to_padded_tensor(
|
||||
torch.nested.nested_tensor(all_input_ids), 1234
|
||||
torch.nested.nested_tensor(all_input_ids), 1234567
|
||||
)
|
||||
.flatten(1)
|
||||
.to(torch.int32)
|
||||
)
|
||||
# lengths of the padded ids_input, i.e. how often 1234 is used.
|
||||
# lengths of the padded ids_input, i.e. how often 1234567 is used.
|
||||
lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(
|
||||
ids_input.device
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user