Defaulting to bf16.

Nicolas Patry 2023-08-15 18:38:42 +02:00
parent defc477a03
commit 8cda9ca2f7


@@ -10,7 +10,9 @@ from transformers import (
 )
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
+from text_generation_server.models.custom_modeling.idefics_processing import (
+    IdeficsProcessor,
+)
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
@@ -35,7 +37,9 @@ class IDEFICSSharded(IdeficsCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            # 9b seems to work correctly enough in float16, but 80b seems
+            # to be really saturating for f16.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
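
For context on the "saturating for f16" remark in the diff (the commit itself only states it as a comment), here is a minimal sketch of the numeric difference: float16 overflows to inf above roughly 65504, while bfloat16 keeps float32's 8-bit exponent range (max around 3.4e38) at the cost of mantissa precision, so large activations in the 80b model saturate in f16 but not in bf16. The value 70000.0 below is just an illustrative example, not taken from the commit.

import torch

# float16 has a max finite value of ~65504, so anything larger rounds to inf.
x = torch.tensor(70000.0)
print(x.to(torch.float16))   # tensor(inf, dtype=torch.float16)

# bfloat16 shares float32's exponent range, so the value stays finite,
# though it is rounded more coarsely (7-bit stored mantissa).
print(x.to(torch.bfloat16))  # tensor(70144., dtype=torch.bfloat16)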