diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index aa045ebf..1cd13a2a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -335,16 +335,9 @@ def get_model( # fbgemm kernels are fp8xfp8->bf16 dtype = torch.bfloat16 else: - config_dtype = config_dict.get("torch_dtype", None) - # Only use the config dtype if its one of TGI's supported dtype - if config_dtype == "float16": - dtype = torch.float16 - elif config_dtype == "bfloat16": - dtype = torch.bfloat16 - else: - # Keep it as default for now and let - # every model resolve their own default dtype. - dtype = None + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16":