From 78ca1414b7f06bc4852b9ea934a8a4a82784c50e Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 8 Oct 2024 08:00:06 -0400
Subject: [PATCH] set kv cache dtype

Signed-off-by: Wang, Yi A
---
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 66db77c7..25cd5970 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -950,10 +950,14 @@ class FlashCausalLM(Model):
                 dtype = default_dtype if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
-                if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                if (
+                    quantize in ["awq", "exl2", "gptq", "marlin"]
+                    and dtype == torch.float16
+                ):
+                    # Float16 doesn't exist on target.
                     dtype = torch.bfloat16
+                    kv_cache_dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")
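
Note (illustration, not part of the patch): below is a minimal, runnable
sketch of the dtype selection this hunk produces on a CPU target. The helper
name resolve_cpu_dtypes, its return shape, and the fallback of kv_cache_dtype
to the model dtype outside the quantized branch are assumptions made for
illustration only; in the real file these variables belong to the surrounding
FlashCausalLM.__init__.

    import torch

    def resolve_cpu_dtypes(quantize=None, dtype=None):
        """Illustrative sketch (not TGI API) of the patched CPU branch."""
        device = torch.device("cpu")
        # Default to bfloat16 when no dtype was requested.
        dtype = torch.bfloat16 if dtype is None else dtype
        # Assumption of this sketch: the KV cache follows the model dtype
        # unless the quantized-float16 fallback below overrides it.
        kv_cache_dtype = dtype
        if quantize in ["awq", "exl2", "gptq", "marlin"] and dtype == torch.float16:
            # Float16 doesn't exist on target.
            dtype = torch.bfloat16
            kv_cache_dtype = torch.bfloat16
        return device, dtype, kv_cache_dtype

    # Example: a GPTQ checkpoint that requested float16 is coerced to
    # bfloat16 for both the weights and the KV cache.
    device, dtype, kv_cache_dtype = resolve_cpu_dtypes("gptq", torch.float16)
    assert dtype == torch.bfloat16 and kv_cache_dtype == torch.bfloat16

Guarding on dtype == torch.float16 leaves an explicitly requested dtype such
as bfloat16 untouched; only the float16 case, which is unsupported on the CPU
target, is rewritten, and the KV cache dtype is now kept in sync with it.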