mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
set kv cache dtype
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
92fa7ac7e9
commit
78ca1414b7
@ -950,10 +950,14 @@ class FlashCausalLM(Model):
|
|||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
if (
|
||||||
|
quantize in ["awq", "exl2", "gptq", "marlin"]
|
||||||
|
and dtype == torch.float16
|
||||||
|
):
|
||||||
|
# Float16 doesn't exist on target.
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
kv_cache_dtype = torch.bfloat16
|
||||||
init_cpu_threads_env(rank_id=rank, world_size=world_size)
|
init_cpu_threads_env(rank_id=rank, world_size=world_size)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{model_class} is only available on GPU")
|
raise NotImplementedError(f"{model_class} is only available on GPU")
|
||||||
|
Loading…
Reference in New Issue
Block a user