Fix idefics default.

Nicolas Patry 2024-02-29 11:33:38 +01:00
parent 343aa7a197
commit 33bfb417b4


@@ -40,7 +40,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             device = torch.device(f"cuda:{rank}")
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
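
For context, a minimal sketch of the default-selection pattern this hunk touches. The resolve_dtype helper below is illustrative only and not part of the repository: an explicitly passed dtype always takes precedence, and this commit only changes the CUDA fallback from bfloat16 to float16 while the CPU fallback stays float32.

import torch

def resolve_dtype(dtype=None, cuda_available=None):
    # Illustrative helper (not in the repo): mirrors the default-selection
    # pattern in the diff above.
    if cuda_available is None:
        cuda_available = torch.cuda.is_available()
    if cuda_available:
        # After this commit, the GPU default is float16 (previously bfloat16).
        return torch.float16 if dtype is None else dtype
    # CPU path keeps float32, unchanged by this commit.
    return torch.float32 if dtype is None else dtype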