mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix
This commit is contained in:
parent
4a02d3505f
commit
58a7719e02
@ -179,16 +179,23 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
rank = weights.process_group.rank()
|
||||
self.q_norm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.q_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.q_norm.weight.data = self.q_norm.weight[
|
||||
self.num_heads * rank : self.num_heads * (rank + 1)
|
||||
]
|
||||
self.k_norm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.k_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.k_norm.weight.data = self.k_norm.weight[
|
||||
self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
|
||||
]
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
@ -221,13 +228,14 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
query = self.q_norm(query.contiguous())
|
||||
key = self.k_norm(key.contiguous())
|
||||
query, _ = self.q_norm(query.contiguous())
|
||||
key, _ = self.k_norm(key.contiguous())
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
@ -253,8 +261,8 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
Loading…
Reference in New Issue
Block a user