diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 14caa23d..8419fa4f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -241,19 +241,21 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         hidden_size = config.hidden_size
         num_heads = config.n_head
-        num_heads_kv = config.n_head_kv
+        # num_heads_kv = config.n_head_kv
+        num_groups = config.n_head_kv
 
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.num_groups = num_groups
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
-        self.num_groups = num_heads // (num_heads_kv * 2)
+        # self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
-        self.num_heads_kv = num_heads_kv // self.num_groups
+        # self.num_heads_kv = num_heads_kv // self.num_groups
 
         process_group = weights.process_group
         if process_group.size() > self.num_groups:
@@ -264,6 +266,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+        self.num_groups = self.num_groups // process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(
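
For context, a minimal arithmetic sketch of what the patch changes: the old code derived the number of KV groups as n_head // (n_head_kv * 2), while the patched code reads it directly from config.n_head_kv and then shards it by the tensor-parallel world size. The concrete numbers below (n_head=232, n_head_kv=8, hidden_size=14848, world size 2) are assumptions for illustration only and do not come from the diff.

# Hedged sketch, not part of the patch: compares the old and new
# num_groups computation from FlashRWLargeAttention.__init__.
# All concrete values are illustrative assumptions.


class DummyConfig:
    hidden_size = 14848  # assumed
    n_head = 232         # assumed number of query heads
    n_head_kv = 8        # assumed number of KV head groups


def group_arithmetic(config, world_size: int):
    head_size = config.hidden_size // config.n_head

    # Old code path: groups derived from the head counts.
    old_num_groups = config.n_head // (config.n_head_kv * 2)

    # New code path: groups taken directly from the config ...
    new_num_groups = config.n_head_kv
    num_heads_per_group = config.n_head // new_num_groups

    # ... then sharded across tensor-parallel ranks, mirroring the added
    # `self.num_groups = self.num_groups // process_group.size()` line.
    assert new_num_groups % world_size == 0
    groups_per_rank = new_num_groups // world_size

    return head_size, old_num_groups, new_num_groups, num_heads_per_group, groups_per_rank


if __name__ == "__main__":
    # With the assumed values this prints (64, 14, 8, 29, 4): the two
    # formulas disagree whenever n_head != 2 * n_head_kv ** 2.
    print(group_arithmetic(DummyConfig(), world_size=2))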