Revert "Idefics force bfloat16"

This reverts commit b8952b2b32.
Nicolas Patry 2023-11-23 13:57:02 +00:00
parent b8952b2b32
commit 861acdeab1


@@ -39,8 +39,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             device = torch.device(f"cuda:{rank}")
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
-            if dtype is None or dtype == torch.float16:
-                dtype = torch.bfloat16
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
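
For context, a minimal sketch of the dtype selection that results from this revert: on GPU, an explicitly requested dtype (including float16) is respected again, and only a missing dtype falls back to bfloat16. The helper name and the surrounding torch.cuda.is_available() check are illustrative assumptions; the branch bodies mirror the hunk above.

import torch

def select_device_and_dtype(rank, dtype=None):
    # Hypothetical helper illustrating the post-revert behaviour of
    # IDEFICSSharded's device/dtype selection.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        # 9b seems to work correctly enough in float16, but 80b seems
        # to be really saturating for f16.
        dtype = torch.bfloat16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype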