Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
revert default dtype
commit 6d8e3659a9 · parent 0d68619efa
@@ -335,16 +335,9 @@ def get_model(
                 # fbgemm kernels are fp8xfp8->bf16
                 dtype = torch.bfloat16
         else:
-            config_dtype = config_dict.get("torch_dtype", None)
-            # Only use the config dtype if its one of TGI's supported dtype
-            if config_dtype == "float16":
-                dtype = torch.float16
-            elif config_dtype == "bfloat16":
-                dtype = torch.bfloat16
-            else:
-                # Keep it as default for now and let
-                # every model resolve their own default dtype.
-                dtype = None
+            # Keep it as default for now and let
+            # every model resolve their own default dtype.
+            dtype = None
     elif dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":
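
In other words, the removed branch had started using the model config's torch_dtype ("float16"/"bfloat16") as the default; this revert restores the behaviour where an unset dtype stays None so that each model resolves its own default. A rough, self-contained sketch of the resolution logic after this commit (the function name resolve_dtype, the simplified quantize handling, and the final error branch are illustrative assumptions, not the exact get_model code):

    import torch

    def resolve_dtype(dtype, quantize=None):
        # Illustrative sketch only: mirrors the post-revert default-dtype
        # logic, with the quantization branches simplified.
        if dtype is None:
            if quantize == "fp8":
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
            else:
                # Keep it as default for now and let
                # every model resolve their own default dtype.
                dtype = None
        elif dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
        else:
            raise RuntimeError(f"Unknown dtype {dtype}")
        return dtype

    # With no explicit dtype and no quantization, dtype stays None and the
    # per-model loader later picks its own default.
    assert resolve_dtype(None) is None
    assert resolve_dtype("bfloat16") == torch.bfloat16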