set kv cache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2024-10-08 08:00:06 -04:00
parent 92fa7ac7e9
commit 78ca1414b7

@@ -950,10 +950,14 @@ class FlashCausalLM(Model):
                 dtype = default_dtype if dtype is None else dtype
             else:
                 device = torch.device("cpu")
                 # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
-                if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                if (
+                    quantize in ["awq", "exl2", "gptq", "marlin"]
+                    and dtype == torch.float16
+                ):
                     # Float16 doesn't exist on target.
                     dtype = torch.bfloat16
+                kv_cache_dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")
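
For context on what the CPU/ipex branch now does, here is a minimal, hypothetical sketch (not code from this commit) of the dtype resolution the diff implements: the model dtype falls back to bfloat16, a float16 request for awq/exl2/gptq/marlin checkpoints is downcast because float16 kernels don't exist on the target, and the KV cache dtype is pinned to bfloat16 regardless of the model dtype. The resolve_cpu_dtypes helper name and the cache shape below are assumptions for illustration only.

import torch

# Hypothetical helper (not part of this commit) mirroring the diff above.
def resolve_cpu_dtypes(dtype, quantize):
    # Float16 doesn't exist on target: default the model dtype to bfloat16.
    dtype = torch.bfloat16 if dtype is None else dtype
    if (
        quantize in ["awq", "exl2", "gptq", "marlin"]
        and dtype == torch.float16
    ):
        # Quantized checkpoints commonly request float16; downcast it.
        dtype = torch.bfloat16
    # The KV cache is always allocated in bfloat16 on this path.
    kv_cache_dtype = torch.bfloat16
    return dtype, kv_cache_dtype

# Usage: a float16 request for a GPTQ model resolves to bfloat16 everywhere.
model_dtype, kv_dtype = resolve_cpu_dtypes(torch.float16, "gptq")
assert model_dtype == torch.bfloat16 and kv_dtype == torch.bfloat16

# Downstream, a KV cache buffer would then be allocated with kv_cache_dtype,
# e.g. (shape is illustrative only):
key_cache = torch.zeros(8, 16, 4, 64, dtype=kv_dtype, device="cpu")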