diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index b1ba2432..2fd67574 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
             dtype = torch.float32