From 6d8e3659a9e291aad1ff77f791773a975be00c6e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:13:53 +0200 Subject: [PATCH] revert default dtype --- server/text_generation_server/models/__init__.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index aa045ebf..1cd13a2a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -335,16 +335,9 @@ def get_model( # fbgemm kernels are fp8xfp8->bf16 dtype = torch.bfloat16 else: - config_dtype = config_dict.get("torch_dtype", None) - # Only use the config dtype if its one of TGI's supported dtype - if config_dtype == "float16": - dtype = torch.float16 - elif config_dtype == "bfloat16": - dtype = torch.bfloat16 - else: - # Keep it as default for now and let - # every model resolve their own default dtype. - dtype = None + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16":