From 090352d965a474c8055c10ea3ace844b07b9cb2c Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 9 May 2023 18:36:20 +0200
Subject: [PATCH] feat(server): use float16

---
 server/text_generation_server/models/bloom.py      | 2 +-
 server/text_generation_server/models/causal_lm.py  | 2 +-
 server/text_generation_server/models/galactica.py  | 2 +-
 server/text_generation_server/models/gpt_neox.py   | 2 +-
 server/text_generation_server/models/opt.py        | 2 +-
 server/text_generation_server/models/santacoder.py | 2 +-
 server/text_generation_server/models/seq2seq_lm.py | 2 +-
 server/text_generation_server/models/t5.py         | 2 +-
 8 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index f528a430..511576da 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 26a9a661..27e1efba 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -452,7 +452,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 2577f1b1..74e158ab 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index e73a3c82..59e5dc24 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 50e5271e..218d8e8a 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -54,7 +54,7 @@ class OPTSharded(OPT):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index a7b09a82..5a142676 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 4ac5ed3c..72f57f4f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
     ):
         if torch.cuda.is_available():
            device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 9e8c3c4c..1db4ba87 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
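
Note (not part of the patch): every hunk above makes the same one-line change to the model constructors. The sketch below isolates that dtype-selection logic before and after the patch; the get_dtype_before/get_dtype_after helpers are hypothetical names used only for illustration, since the patch edits this logic inline in each model's __init__.

    # Minimal sketch of the dtype selection changed by this patch.
    # "get_dtype_before" / "get_dtype_after" are illustrative helpers,
    # not functions that exist in text-generation-inference.
    import torch

    def get_dtype_before() -> torch.dtype:
        if torch.cuda.is_available():
            # Old behaviour: prefer bfloat16 when the GPU supports it,
            # otherwise fall back to float32.
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        return torch.float32

    def get_dtype_after() -> torch.dtype:
        if torch.cuda.is_available():
            # New behaviour: always load weights in float16 on GPU.
            return torch.float16
        return torch.float32

The apparent trade-off: GPUs without bfloat16 support previously fell back to float32, whereas float16 halves the memory footprint on such hardware at the cost of bfloat16's wider exponent range.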