Default dtype bfloat16.

This commit is contained in:
Nicolas Patry 2024-10-01 10:52:19 +02:00
parent 7ede61bca6
commit d735e46ef5
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@ -1147,6 +1147,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )