mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Idefics force bfloat16
parent 96a982ad8f
commit b8952b2b32
@@ -39,7 +39,8 @@ class IDEFICSSharded(IdeficsCausalLM):
             device = torch.device(f"cuda:{rank}")
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
-            dtype = torch.bfloat16 if dtype is None else dtype
+            if dtype is None or dtype == torch.float16:
+                dtype = torch.bfloat16
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
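
The hunk forces bfloat16 on GPU whenever no dtype or float16 is requested, instead of only defaulting to bfloat16 when nothing is requested; other explicit dtypes pass through unchanged. Below is a minimal sketch of the resulting selection logic, assuming it sits inside the torch.cuda.is_available() branch shown in the diff; the standalone helper name pick_idefics_dtype is hypothetical and not part of the repository's code.

# Sketch only: the branch structure and comments come from the diff above,
# the helper wrapper is illustrative.
import torch


def pick_idefics_dtype(dtype, rank=0):
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        # 9b seems to work correctly enough in float16, but 80b seems
        # to be really saturating for f16.
        if dtype is None or dtype == torch.float16:
            dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype


# On a CUDA machine, an explicit float16 request is upgraded to bfloat16:
#   pick_idefics_dtype(torch.float16, rank=0) -> (device("cuda:0"), torch.bfloat16)
# while, for example, an explicit float32 request passes through unchanged.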