From 17594916edfa71418de2adecfff6ca667b78e70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 11:19:48 +0200 Subject: [PATCH] Fix incorrect cache allocation with multi-query (#2203) We wouldn't allocate any memory in multi-query (1 KV head). Fixes Starcoder et al. --- server/text_generation_server/models/flash_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42b2f686..5c086a73 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -912,7 +912,12 @@ class FlashCausalLM(Model): break if num_kv_heads is None: raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = num_kv_heads // self.process_group.size() + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {}