revert default dtype

This commit is contained in:
OlivierDehaene 2024-07-22 16:13:53 +02:00
parent 0d68619efa
commit 6d8e3659a9
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -335,16 +335,9 @@ def get_model(
# fbgemm kernels are fp8xfp8->bf16 # fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16 dtype = torch.bfloat16
else: else:
config_dtype = config_dict.get("torch_dtype", None) # Keep it as default for now and let
# Only use the config dtype if its one of TGI's supported dtype # every model resolve their own default dtype.
if config_dtype == "float16": dtype = None
dtype = torch.float16
elif config_dtype == "bfloat16":
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.
dtype = None
elif dtype == "float16": elif dtype == "float16":
dtype = torch.float16 dtype = torch.float16
elif dtype == "bfloat16": elif dtype == "bfloat16":