mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
dtype default to None instead of float16 and each model could set it's default type according to the platform
cuda: float16 if dtype is None cpu: bfloat16 if dtype is None Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
8672cad2cb
commit
45bf7597ac
@ -76,13 +76,11 @@ def get_model(
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
) -> Model:
|
||||
if dtype is None:
|
||||
dtype = torch.float16
|
||||
elif dtype == "float16":
|
||||
if dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
elif dtype is not None:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
if "facebook/galactica" in model_id:
|
||||
|
@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -492,7 +492,7 @@ class CausalLM(Model):
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -174,7 +174,7 @@ class GalacticaSharded(CausalLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
self.device, self.dtype = device, dtype
|
||||
|
||||
config = IdeficsConfig.from_pretrained(
|
||||
|
@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -52,7 +52,7 @@ class MPTSharded(CausalLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -23,7 +23,7 @@ class RW(CausalLM):
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_id,
|
||||
|
@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
|
Loading…
Reference in New Issue
Block a user