diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 66db77c7..25cd5970 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -950,10 +950,14 @@ class FlashCausalLM(Model):
                 dtype = default_dtype if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
-                if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                if (
+                    quantize in ["awq", "exl2", "gptq", "marlin"]
+                    and dtype == torch.float16
+                ):
+                    # Float16 doesn't exist on target.
                     dtype = torch.bfloat16
+                    kv_cache_dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")
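
For reviewers who want to sanity-check the new branch in isolation, here is a minimal standalone sketch of the dtype selection this hunk implements. The helper name `resolve_cpu_dtypes` is hypothetical (not a function in this file), and initializing `kv_cache_dtype` from the requested dtype is an assumption made only to keep the sketch self-contained:

```python
import torch


def resolve_cpu_dtypes(dtype, quantize):
    """Mirror the patched CPU (ipex) branch: default a missing dtype to
    bfloat16, and coerce an explicit float16 request to bfloat16 for
    quantization schemes that would otherwise keep float16, since
    float16 doesn't exist on the CPU target."""
    # Assumption for self-containment: seed the KV-cache dtype from the
    # requested dtype; in flash_causal_lm.py it is handled elsewhere.
    kv_cache_dtype = dtype
    dtype = torch.bfloat16 if dtype is None else dtype
    if (
        quantize in ["awq", "exl2", "gptq", "marlin"]
        and dtype == torch.float16
    ):
        # Float16 doesn't exist on target.
        dtype = torch.bfloat16
        kv_cache_dtype = torch.bfloat16
    return dtype, kv_cache_dtype


# An explicit float16 request with GPTQ now falls back to bfloat16 for
# both the model weights and the KV cache, while a non-quantized
# float16 request is left untouched by this hunk.
assert resolve_cpu_dtypes(torch.float16, "gptq") == (torch.bfloat16, torch.bfloat16)
assert resolve_cpu_dtypes(torch.float16, None) == (torch.float16, torch.float16)
```

Note the behavioral change: before this patch the coercion to bfloat16 ran for any of the four quantization schemes regardless of the requested dtype; it is now gated on an explicit float16 request, and the KV-cache dtype is coerced along with the model dtype.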