dtype defaults to None instead of float16, and each model sets its own default dtype according to the platform:
cuda: float16 if dtype is None
cpu: bfloat16 if dtype is None

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 45bf7597ac
parent 8672cad2cb
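
The net behavior, as a minimal sketch (the default_dtype helper below is illustrative, not part of the change; in the actual diff each model class inlines the equivalent one-liner in its __init__, and IDEFICS already defaulted to bfloat16 on CUDA):

    import torch

    def default_dtype(device: torch.device, dtype=None) -> torch.dtype:
        # Honor an explicit request; otherwise pick a platform default.
        # CUDA keeps float16 (bfloat16 for IDEFICS); CPU moves from
        # float32 to bfloat16.
        if dtype is not None:
            return dtype
        return torch.float16 if device.type == "cuda" else torch.bfloat16

    print(default_dtype(torch.device("cpu")))  # torch.bfloat16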
@@ -76,13 +76,11 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    if dtype is None:
-        dtype = torch.float16
-    elif dtype == "float16":
+    if dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":
         dtype = torch.bfloat16
-    else:
+    elif dtype is not None:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
     if "facebook/galactica" in model_id:
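One consequence in this hunk: since None is now a legal value all the way through get_model, the catch-all `else:` becomes `elif dtype is not None:`; otherwise every launch without an explicit dtype would raise "Unknown dtype None". A standalone sketch of the new branch (parse_dtype is a hypothetical name; the real code reassigns a local variable):

    from typing import Optional

    import torch

    def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
        # Convert a known string, reject an unknown one, pass None through.
        if dtype == "float16":
            return torch.float16
        elif dtype == "bfloat16":
            return torch.bfloat16
        elif dtype is not None:
            raise RuntimeError(f"Unknown dtype {dtype}")
        return None

    assert parse_dtype(None) is None  # each model picks its default later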
@@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -492,7 +492,7 @@ class CausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -174,7 +174,7 @@ class GalacticaSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
         config = IdeficsConfig.from_pretrained(
@@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -52,7 +52,7 @@ class MPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -23,7 +23,7 @@ class RW(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
@@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
 
         config = AutoConfig.from_pretrained(
             model_id,
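All of the CPU branches above now default to bfloat16. A quick sanity check that the local PyTorch build runs bf16 kernels on CPU (a sketch assuming a reasonably recent PyTorch; some older builds lack bf16 support for certain CPU ops):

    import torch

    x = torch.randn(4, 4, dtype=torch.bfloat16)
    y = x @ x  # bfloat16 matmul on CPU
    print(y.dtype)  # torch.bfloat16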