diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index c4de21a7..c54b539b 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -10,7 +10,9 @@ from transformers import (
 )

 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
+from text_generation_server.models.custom_modeling.idefics_processing import (
+    IdeficsProcessor,
+)
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
@@ -35,7 +37,9 @@ class IDEFICSSharded(IdeficsCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            # 9b seems to work correctly enough in float16, but 80b seems
+            # to be really saturating for f16.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
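
For context on the "saturating for f16" comment above, a minimal sketch (plain stock PyTorch, not part of the diff) of why float16 overflows on large activations while bfloat16 does not: float16 tops out near 65504, whereas bfloat16 keeps float32's exponent range at the cost of mantissa precision.

    # Sketch only: illustrates the dtype range difference behind this change.
    import torch

    print(torch.finfo(torch.float16).max)   # 65504.0
    print(torch.finfo(torch.bfloat16).max)  # 3.3895313892515355e+38

    x = torch.tensor(70000.0)
    print(x.to(torch.float16))   # inf -- overflows float16's range
    print(x.to(torch.bfloat16))  # 70144.0 -- coarsely rounded, but finite

This is why the larger 80b checkpoint, whose activations can exceed float16's range, behaves better with a bfloat16 default, while 9b happens to stay within range either way.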