revert default dtype

This commit is contained in:
OlivierDehaene 2024-07-22 16:13:53 +02:00
parent 0d68619efa
commit 6d8e3659a9
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -334,13 +334,6 @@ def get_model(
if FBGEMM_DYN_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else:
config_dtype = config_dict.get("torch_dtype", None)
# Only use the config dtype if its one of TGI's supported dtype
if config_dtype == "float16":
dtype = torch.float16
elif config_dtype == "bfloat16":
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.