From 645369bef7f610677ea2c85f91809a2ced8c6e1c Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 8 Oct 2024 08:00:06 -0400
Subject: [PATCH] set kv cache dtype

Signed-off-by: Wang, Yi A
---
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0f4bc415..07d65b77 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1122,10 +1122,14 @@ class FlashCausalLM(Model):
                 dtype = default_dtype if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
-                if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                if (
+                    quantize in ["awq", "exl2", "gptq", "marlin"]
+                    and dtype == torch.float16
+                ):
+                    # Float16 doesn't exist on target.
                     dtype = torch.bfloat16
+                    kv_cache_dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")