From 7ef9aac063c47f7acca27a8061aac7c2fdaf838c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 23 May 2023 12:15:05 +0200 Subject: [PATCH] fix(server): t5 cannot run in f16 --- server/text_generation_server/models/t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index b1ba2432..2fd67574 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32