diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 6376788b..a652c1ca 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -179,16 +179,23 @@ class FlashCohereAttention(torch.nn.Module):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
+            rank = weights.process_group.rank()
             self.q_norm = FastRMSNorm.load(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
+            self.q_norm.weight.data = self.q_norm.weight[
+                self.num_heads * rank : self.num_heads * (rank + 1)
+            ]
             self.k_norm = FastRMSNorm.load(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
+            self.k_norm.weight.data = self.k_norm.weight[
+                self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
+            ]
         else:
             self.q_norm = None
             self.k_norm = None
 
@@ -221,13 +228,14 @@ class FlashCohereAttention(torch.nn.Module):
             ],
             dim=1,
         )
+
         if self.use_qk_norm:
-            query = self.q_norm(query.contiguous())
-            key = self.k_norm(key.contiguous())
+            query, _ = self.q_norm(query.contiguous())
+            key, _ = self.k_norm(key.contiguous())
 
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
-        value = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = value.view(-1, self.num_key_value_heads, self.head_size)
 
         self.rotary_emb(query, key, cos, sin)
 
@@ -253,8 +261,8 @@ class FlashCohereAttention(torch.nn.Module):
         paged_attention.attention(
             attn_output,
             query,
-            key,
-            value,
+            kv_cache[0],
+            kv_cache[1],
             self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
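
Note on the first hunk: FastRMSNorm.load reads the full, unsharded q_norm/k_norm weights from the checkpoint, while each tensor-parallel rank only holds the query/key heads assigned to it, so the norm weight has to be narrowed to that rank's heads. Below is a minimal standalone sketch of that slicing; the head counts, world_size/rank values, weight shape, and variable names are illustrative assumptions, not TGI code.

import torch

# Illustrative values (assumptions): 64 query heads of size 128,
# 4-way tensor parallelism, and this process is rank 1.
num_heads, head_size = 64, 128
world_size, rank = 4, 1
local_heads = num_heads // world_size  # heads owned by this rank

# Assume the checkpoint stores one norm-weight row per *global* head.
full_q_norm_weight = torch.ones(num_heads, head_size)

# Keep only the rows for this rank's heads, mirroring
# self.q_norm.weight[self.num_heads * rank : self.num_heads * (rank + 1)]
# in the patch (self.num_heads there is already the per-rank head count).
local_q_norm_weight = full_q_norm_weight[local_heads * rank : local_heads * (rank + 1)]

# This rank's query activations only cover its local heads, so the sliced
# weight lines up with them: [tokens, local_heads, head_size].
query = torch.randn(10, local_heads, head_size)
assert local_q_norm_weight.shape == query.shape[1:]

Without the slice, every rank would keep a weight sized for all heads while its activations only cover the local ones, so the per-head scaling would no longer match the sharded query/key projections.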