Defaulting to bf16.

Nicolas Patry 2023-08-15 18:38:42 +02:00
parent defc477a03
commit 8cda9ca2f7


@@ -10,7 +10,9 @@ from transformers import (
 )
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
+from text_generation_server.models.custom_modeling.idefics_processing import (
+    IdeficsProcessor,
+)
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
@@ -35,7 +37,9 @@ class IDEFICSSharded(IdeficsCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            # The 9b model works correctly enough in float16, but the 80b
+            # model saturates (overflows) in float16, so default to bfloat16.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
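
For context, a minimal sketch (not part of the commit) of why float16 "saturates" while bfloat16 does not: float16's largest finite value is 65504, so large activations in the 80b model overflow to inf, whereas bfloat16 keeps float32's exponent range (up to roughly 3.4e38) at the cost of mantissa precision. The values below are illustrative, not taken from the model.

# Sketch: float16 overflows where bfloat16 merely loses precision.
import torch

x = torch.tensor([70000.0])  # exceeds float16's max finite value (65504)
print(x.to(torch.float16))   # tensor([inf], dtype=torch.float16)  -> saturated
print(x.to(torch.bfloat16))  # tensor([70144.], dtype=torch.bfloat16) -> representable, rounded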