Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix

parent 4a02d3505f
commit 58a7719e02
@@ -179,16 +179,23 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
+            rank = weights.process_group.rank()
             self.q_norm = FastRMSNorm.load(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
+            self.q_norm.weight.data = self.q_norm.weight[
+                self.num_heads * rank : self.num_heads * (rank + 1)
+            ]
             self.k_norm = FastRMSNorm.load(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
+            self.k_norm.weight.data = self.k_norm.weight[
+                self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
+            ]
         else:
             self.q_norm = None
             self.k_norm = None
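
Note on the q_norm/k_norm slicing above: under tensor parallelism the query/key projections are sharded by attention head across ranks, while the per-head norm weight is loaded in full on every rank, so the new lines keep only the rows for the heads the local rank owns. A minimal sketch of the slicing pattern, with illustrative shapes and names (not TGI's actual weight layout):

import torch

# Illustrative sizes only: 8 heads split across 2 tensor-parallel ranks.
total_heads, head_size, world_size = 8, 4, 2
num_heads = total_heads // world_size  # heads owned by each rank

# Full per-head norm weight as it would come out of the checkpoint.
full_weight = torch.randn(total_heads, head_size)

for rank in range(world_size):
    # Same pattern as the diff: each rank keeps the contiguous block of
    # per-head weights matching its shard of the q/k projection.
    shard = full_weight[num_heads * rank : num_heads * (rank + 1)]
    assert shard.shape == (num_heads, head_size)
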
@@ -221,13 +228,14 @@ class FlashCohereAttention(torch.nn.Module):
             ],
             dim=1,
         )
+
         if self.use_qk_norm:
-            query = self.q_norm(query.contiguous())
-            key = self.k_norm(key.contiguous())
+            query, _ = self.q_norm(query.contiguous())
+            key, _ = self.k_norm(key.contiguous())
 
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
-        value = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = value.view(-1, self.num_key_value_heads, self.head_size)
 
         self.rotary_emb(query, key, cos, sin)
 
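This hunk carries two separate fixes: FastRMSNorm's forward returns a tuple rather than a bare tensor, so the call sites must unpack it, and value was mistakenly built by viewing key, which made V a duplicate of K. A small sketch of the tuple-unpacking issue, using a stand-in norm (the real FastRMSNorm signature may differ):

import torch


class TupleRMSNorm(torch.nn.Module):
    """Stand-in for a norm whose forward returns (output, residual)."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * self.weight, x  # (normed output, residual)


norm = TupleRMSNorm(64)
x = torch.randn(4, 64)

query = norm(x)      # bug: binds the whole (output, residual) tuple
query, _ = norm(x)   # fix: keeps only the normalized tensor
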
@@ -253,8 +261,8 @@ class FlashCohereAttention(torch.nn.Module):
             paged_attention.attention(
                 attn_output,
                 query,
-                key,
-                value,
+                kv_cache[0],
+                kv_cache[1],
                 self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
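
The last hunk points the decode-time attention call at the paged KV cache instead of this step's key/value tensors: during decode, attention must see every past token, and those live in kv_cache blocks indexed by block_tables rather than in the current step's projections. A toy sketch of the paged lookup, with made-up shapes (not the real kernel API):

import torch

# Toy cache layout: [num_blocks, block_size, num_kv_heads, head_size].
num_blocks, block_size, num_kv_heads, head_size = 16, 4, 2, 8
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# block_tables maps each sequence to the cache blocks it owns.
block_tables = torch.tensor([[3, 7, 1]])
seq_len = 10  # tokens actually stored for sequence 0

# Gather the sequence's keys back into one contiguous view; a real paged
# attention kernel performs this gather implicitly while computing attention.
keys = key_cache[block_tables[0]].reshape(-1, num_kv_heads, head_size)[:seq_len]
assert keys.shape == (seq_len, num_kv_heads, head_size)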