Large attention ?

Author: Ubuntu
Date: 2023-06-06 11:08:25 +00:00
parent d083d57d0d
commit daf59b0582


@@ -209,11 +209,12 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
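
For orientation, not part of the commit: the new signature follows the weights-loader pattern used across this refactor, so the module receives the model config, its parameter-name prefix, and a weights handle instead of raw dimensions and a process group. A hypothetical call site, as a sketch only (the prefix string and layer_id are assumptions, not taken from this diff):

# Hypothetical call site; prefix layout and layer_id are illustrative assumptions.
attn = FlashRWLargeAttention(
    config,
    prefix=f"transformer.h.{layer_id}.self_attention",  # assumed parameter prefix
    weights=weights,  # loader exposing the checkpoint tensors and the process group
)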
@@ -221,46 +222,24 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
         self.softmax_scale = self.head_size ** (-0.5)
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
+        process_group = weights.process_group

-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            if process_group.size() > self.num_groups:
-                raise NotImplementedError(
-                    f"Tensor Parallelism is not implemented for world_size > n groups"
-                )
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.num_groups = self.num_groups // process_group.size()
+        if process_group.size() > self.num_groups:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for world_size > n groups"
+            )
+        if self.num_groups % process_group.size() != 0:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
+            )
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)

     def forward(
         self,
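
For orientation, not part of the commit: the group arithmetic retained above describes how the RW "large" layout packs query heads around shared key/value heads, and it determines the width of the fused query_key_value projection that is now loaded in one piece. A worked example in plain Python, assuming Falcon-40B-style values (n_head=128, n_head_kv=8, hidden_size=8192 are assumptions, not read from this diff):

# Illustrative arithmetic only; the concrete numbers are assumed, not from the commit.
num_heads, num_heads_kv, hidden_size = 128, 8, 8192

head_size = hidden_size // num_heads            # 64
num_groups = num_heads // (num_heads_kv * 2)    # 8 groups
heads_per_group = num_heads // num_groups       # 16 query heads per group
kv_per_group = num_heads_kv // num_groups       # 1 key head and 1 value head per group

# Width of the fused query_key_value projection: each group stores its query
# heads plus one key head and one value head side by side.
qkv_width = num_groups * head_size * (heads_per_group + 2 * kv_per_group)
assert qkv_width == 9216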
@@ -460,8 +439,6 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output

-            # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)
+            torch.distributed.all_reduce(intermediate, group=self.process_group)

             return intermediate, residual
@@ -548,8 +525,6 @@ class FlashRWLargeLayer(nn.Module):
             intermediate = attn_output + mlp_output

-            # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)
+            torch.distributed.all_reduce(intermediate, group=self.process_group)

             return intermediate, residual
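
Both layer hunks make the same change: the `if self.process_group is not None:` guard goes away and `torch.distributed.all_reduce` runs unconditionally on the summed branch outputs. The deleted comment stated the underlying idea: with the parallel attention/MLP layout, each rank holds partial sums from its row-parallel output projections, and because the reduction is a sum, reducing `attn_output + mlp_output` once is equivalent to reducing each branch separately while halving the communication. A runnable sketch of that equivalence, simulating two ranks with plain tensors (all names here are illustrative, not from the commit):

# Illustrative sketch only: addition commutes with the sum-reduction, so one
# "all-reduce" after the addition equals two separate reductions.
import torch

attn_partials = [torch.randn(4) for _ in range(2)]  # per-rank partial attention outputs
mlp_partials = [torch.randn(4) for _ in range(2)]   # per-rank partial MLP outputs

reduce_each = sum(attn_partials) + sum(mlp_partials)                    # two reductions
reduce_once = sum(a + m for a, m in zip(attn_partials, mlp_partials))   # one reduction
assert torch.allclose(reduce_each, reduce_once)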