Let each model resolve their own default dtype.

This commit is contained in:
Nicolas Patry 2023-11-27 10:30:35 +00:00
parent ed2a3f617e
commit 2713b21132

View File

@ -87,7 +87,9 @@ def get_model(
trust_remote_code: bool,
) -> Model:
if dtype is None:
dtype = torch.float16
# Keep it as default for now and let
# every model resolve their own default dtype.
dtype = None
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":