Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
fix(server): t5 cannot run in f16
parent 91d9beec90
commit 7ef9aac063
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
             dtype = torch.float32
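The change swaps the hard-coded float16 dtype for a runtime check: T5 is prone to overflow in float16, so the model uses bfloat16 when the GPU supports it and falls back to float32 otherwise. A minimal sketch of that selection logic is below; choose_t5_dtype is a hypothetical helper name used here for illustration, not a function in the repository.

import torch

def choose_t5_dtype() -> torch.dtype:
    # T5 tends to overflow in float16, so prefer bfloat16 (same exponent
    # range as float32) on GPUs that support it, else fall back to float32.
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    # CPU path: run in full precision.
    return torch.float32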