Let each model resolve their own default dtype.

This commit is contained in:
Nicolas Patry 2023-11-27 10:30:35 +00:00
parent ed2a3f617e
commit 2713b21132

View File

@ -87,7 +87,9 @@ def get_model(
trust_remote_code: bool, trust_remote_code: bool,
) -> Model: ) -> Model:
if dtype is None: if dtype is None:
dtype = torch.float16 # Keep it as default for now and let
# every model resolve their own default dtype.
dtype = None
elif dtype == "float16": elif dtype == "float16":
dtype = torch.float16 dtype = torch.float16
elif dtype == "bfloat16": elif dtype == "bfloat16":