Nicolas Patry 2023-08-30 08:36:09 +00:00
parent 211b54ac41
commit 7b88baddf7


@@ -241,19 +241,21 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         hidden_size = config.hidden_size
         num_heads = config.n_head
-        num_heads_kv = config.n_head_kv
+        # num_heads_kv = config.n_head_kv
+        num_groups = config.n_head_kv
 
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.num_groups = num_groups
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
-        self.num_groups = num_heads // (num_heads_kv * 2)
+        # self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
-        self.num_heads_kv = num_heads_kv // self.num_groups
+        # self.num_heads_kv = num_heads_kv // self.num_groups
 
         process_group = weights.process_group
         if process_group.size() > self.num_groups:
@@ -264,6 +266,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+        self.num_groups = self.num_groups // process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(