diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42b2f686..5c086a73 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -912,7 +912,12 @@ class FlashCausalLM(Model): break if num_kv_heads is None: raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = num_kv_heads // self.process_group.size() + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {}