mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
set kv cache dtype
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
92fa7ac7e9
commit
78ca1414b7
@ -950,10 +950,14 @@ class FlashCausalLM(Model):
|
||||
dtype = default_dtype if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
# Float16 doesn't exist on target.
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
if (
|
||||
quantize in ["awq", "exl2", "gptq", "marlin"]
|
||||
and dtype == torch.float16
|
||||
):
|
||||
# Float16 doesn't exist on target.
|
||||
dtype = torch.bfloat16
|
||||
kv_cache_dtype = torch.bfloat16
|
||||
init_cpu_threads_env(rank_id=rank, world_size=world_size)
|
||||
else:
|
||||
raise NotImplementedError(f"{model_class} is only available on GPU")
|
||||
|
Loading…
Reference in New Issue
Block a user