This commit is contained in:
OlivierDehaene 2024-04-04 19:11:50 +02:00
parent 4a02d3505f
commit 58a7719e02

View File

@ -179,16 +179,23 @@ class FlashCohereAttention(torch.nn.Module):
self.use_qk_norm = config.use_qk_norm self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm: if self.use_qk_norm:
rank = weights.process_group.rank()
self.q_norm = FastRMSNorm.load( self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm", prefix=f"{prefix}.q_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.q_norm.weight.data = self.q_norm.weight[
self.num_heads * rank : self.num_heads * (rank + 1)
]
self.k_norm = FastRMSNorm.load( self.k_norm = FastRMSNorm.load(
prefix=f"{prefix}.k_norm", prefix=f"{prefix}.k_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.k_norm.weight.data = self.k_norm.weight[
self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
]
else: else:
self.q_norm = None self.q_norm = None
self.k_norm = None self.k_norm = None
@ -221,13 +228,14 @@ class FlashCohereAttention(torch.nn.Module):
], ],
dim=1, dim=1,
) )
if self.use_qk_norm: if self.use_qk_norm:
query = self.q_norm(query.contiguous()) query, _ = self.q_norm(query.contiguous())
key = self.k_norm(key.contiguous()) key, _ = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size) key = key.view(-1, self.num_key_value_heads, self.head_size)
value = key.view(-1, self.num_key_value_heads, self.head_size) value = value.view(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
@ -253,8 +261,8 @@ class FlashCohereAttention(torch.nn.Module):
paged_attention.attention( paged_attention.attention(
attn_output, attn_output,
query, query,
key, kv_cache[0],
value, kv_cache[1],
self.num_key_value_heads, self.num_key_value_heads,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,