dtype defaults to None instead of float16, so each model can set its own default dtype according to the platform:

cuda: float16 if dtype is None
cpu: bfloat16 if dtype is None

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2023-09-21 19:20:41 -07:00
parent 8672cad2cb
commit 45bf7597ac
13 changed files with 14 additions and 16 deletions
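
In effect, dtype resolution moves from one global default to a per-model, per-platform rule: an explicit dtype always wins, and None is resolved by each model class at load time. A minimal sketch of the policy, assuming a hypothetical helper name (resolve_default_dtype is not part of this commit):

    import torch
    from typing import Optional

    def resolve_default_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
        # Hypothetical helper illustrating the new policy: an explicit
        # dtype always wins; otherwise the default depends on the device.
        if dtype is not None:
            return dtype
        if torch.cuda.is_available():
            return torch.float16   # CUDA default
        return torch.bfloat16      # CPU default (was float32 before this commit)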


@@ -76,13 +76,11 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    if dtype is None:
-        dtype = torch.float16
-    elif dtype == "float16":
+    if dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":
         dtype = torch.bfloat16
-    else:
+    elif dtype is not None:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
     if "facebook/galactica" in model_id:


@@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@ -492,7 +492,7 @@ class CausalLM(Model):
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,


@@ -174,7 +174,7 @@ class GalacticaSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
         config = IdeficsConfig.from_pretrained(


@@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -52,7 +52,7 @@ class MPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -23,7 +23,7 @@ class RW(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,


@@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,


@@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         config = AutoConfig.from_pretrained(
             model_id,
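
Across all twelve model classes the CPU branch changes the same way: the fallback moves from float32 to bfloat16, which halves weight memory while keeping float32's exponent range, and maps well to CPUs with native bf16 support (e.g. Intel AVX-512 BF16 or AMX). In user terms, a model served on CPU now loads in bfloat16 unless a dtype is requested explicitly. A minimal standalone sketch of the equivalent transformers call (the model id is a placeholder, not taken from this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "bigscience/bloom-560m"  # placeholder; any causal LM works

    # Equivalent of the new CPU default: load weights in bfloat16
    # rather than float32.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    print(next(model.parameters()).dtype)  # torch.bfloat16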