fix(server): t5 cannot run in f16

OlivierDehaene 2023-05-23 12:15:05 +02:00
parent 91d9beec90
commit 7ef9aac063


@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
             dtype = torch.float32
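
For context on why the patch swaps float16 for bfloat16: float16 has a 5-bit exponent (largest finite value 65504), while bfloat16 keeps float32's 8-bit exponent range, so intermediate activations that overflow to inf in float16 stay finite in bfloat16. Below is a minimal sketch of the range difference and the same dtype-selection logic as the patch, using only standard PyTorch APIs (torch.finfo, torch.cuda.is_bf16_supported); the T5Sharded wiring is omitted and the example values are illustrative, not taken from the model.

import torch

# float16: 5-bit exponent, max finite value 65504.
# bfloat16: 8-bit exponent (same range as float32), max ~3.39e38.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # 3.3895313892515355e+38

# A value of this magnitude overflows float16 but survives the cast
# to bfloat16, at reduced mantissa precision.
x = torch.tensor([1e5])
print(x.to(torch.float16))   # tensor([inf], dtype=torch.float16)
print(x.to(torch.bfloat16))  # tensor([99840.], dtype=torch.bfloat16)

# Same dtype selection as the patch; falls back to float32 on GPUs
# without bfloat16 support and on CPU-only hosts.
if torch.cuda.is_available():
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
    dtype = torch.float32
print(dtype)

The trade-off is that bfloat16 gives up mantissa bits for exponent bits, so it is less precise per value than float16, but for T5 the wider dynamic range is what prevents the inf/nan activations that made f16 unusable.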