From b2fb845923eaec07de43c23405476a1fd8885900 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 2 Jul 2024 15:37:27 +0000
Subject: [PATCH] Fixing sharding.

---
 server/text_generation_server/models/flash_causal_lm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f2e66d56..a5da215a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -881,7 +881,9 @@ class FlashCausalLM(Model):
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         self.num_layers = config.num_hidden_layers
-        self.num_kv_heads = config.num_key_value_heads
+
+        # Validation is done in the model itself
+        self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
 
         self.cuda_graphs = {}
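
Note (not part of the patch itself): under tensor parallelism each rank holds only its slice of the attention heads, so self.num_kv_heads, which is later used to size the per-rank KV cache, should reflect the sharded count rather than the global config.num_key_value_heads. Below is a minimal sketch of that arithmetic; shard_kv_heads and world_size are illustrative names standing in for self.process_group.size(), not TGI's actual API, and the divisibility check stands in for the validation the patch comment says is done in the model itself.

    # Minimal sketch, assuming KV heads divide evenly across ranks.
    def shard_kv_heads(num_key_value_heads: int, world_size: int) -> int:
        """Return the number of KV heads held by each tensor-parallel rank."""
        if num_key_value_heads % world_size != 0:
            # Grouped-query models may instead replicate KV heads across ranks;
            # that case is handled by the model's own validation, not here.
            raise ValueError(
                f"{num_key_value_heads} KV heads cannot be split evenly "
                f"across {world_size} shards"
            )
        return num_key_value_heads // world_size

    # Example: a model with 8 KV heads sharded over 4 GPUs keeps 2 per rank,
    # so each rank's KV cache is allocated with 2 heads instead of 8.
    assert shard_kv_heads(8, 4) == 2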