From 45bf7597accc4e9444d9abcec7d38d271ed75406 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 21 Sep 2023 19:20:41 -0700
Subject: [PATCH] dtype defaults to None instead of float16 so that each model
 can set its own default dtype according to the platform

cuda: float16 if dtype is None
cpu: bfloat16 if dtype is None

Signed-off-by: Wang, Yi A
---
 server/text_generation_server/models/__init__.py          | 6 ++----
 server/text_generation_server/models/bloom.py             | 2 +-
 server/text_generation_server/models/causal_lm.py         | 2 +-
 server/text_generation_server/models/galactica.py         | 2 +-
 server/text_generation_server/models/gpt_neox.py          | 2 +-
 server/text_generation_server/models/idefics.py           | 2 +-
 server/text_generation_server/models/idefics_causal_lm.py | 2 +-
 server/text_generation_server/models/mpt.py               | 2 +-
 server/text_generation_server/models/opt.py               | 2 +-
 server/text_generation_server/models/rw.py                | 2 +-
 server/text_generation_server/models/santacoder.py        | 2 +-
 server/text_generation_server/models/seq2seq_lm.py        | 2 +-
 server/text_generation_server/models/t5.py                | 2 +-
 13 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0d96d43b..95c33028 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -76,13 +76,11 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    if dtype is None:
-        dtype = torch.float16
-    elif dtype == "float16":
+    if dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":
         dtype = torch.bfloat16
-    else:
+    elif dtype is not None:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
     if "facebook/galactica" in model_id:
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 0151b017..caa3a708 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 696f0fb2..570d6dbe 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -492,7 +492,7 @@ class CausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index b296c96e..9e750971 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -174,7 +174,7 @@ class GalacticaSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index d4c64dfe..8e637d76 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index fa23d1f9..556bff08 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
         config = IdeficsConfig.from_pretrained(
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index f4177145..de50bbb6 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 19de497c..6cf17fe0 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -52,7 +52,7 @@ class MPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index b2b87246..fb25ce9e 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 802a4aa6..b9d0a065 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -23,7 +23,7 @@ class RW(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 7b269d8e..ac214136 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 34932c0b..c6231160 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 161e69ba..1d0b4c1b 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         config = AutoConfig.from_pretrained(
             model_id,
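
After this change, a model constructed with dtype=None picks its own default from the platform: float16 on CUDA (bfloat16 for IDEFICS), bfloat16 on CPU, while an explicitly requested dtype is always respected. A minimal sketch of that fallback logic follows; resolve_dtype is a hypothetical helper written for illustration and is not part of the patch:

import torch

def resolve_dtype(dtype, device):
    # Per-platform default introduced by this patch: when the caller passes
    # no dtype, CUDA falls back to float16 and CPU falls back to bfloat16.
    if dtype is not None:
        return dtype
    return torch.float16 if device.type == "cuda" else torch.bfloat16

# Example usage (illustrative only):
# resolve_dtype(None, torch.device("cpu"))     -> torch.bfloat16
# resolve_dtype(None, torch.device("cuda:0"))  -> torch.float16
# resolve_dtype(torch.float32, torch.device("cpu")) -> torch.float32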