Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
Large attention ?
This commit is contained in:
parent
d083d57d0d
commit
daf59b0582
@@ -209,11 +209,12 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
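The first hunk swaps the explicit hyperparameter arguments for the (config, prefix, weights) pattern used throughout this refactor, keeping the old parameters around as comments. A minimal sketch of how the old arguments map onto config attributes; the attribute names (n_head, n_head_kv, hidden_size, bias) follow the Falcon/RW config convention and are an assumption, not something shown in this hunk:

# Sketch only: attribute names are assumed (Falcon/RW-style config), not taken from the diff.
from types import SimpleNamespace

config = SimpleNamespace(hidden_size=8192, n_head=128, n_head_kv=8, bias=False)

num_heads = config.n_head         # previously the num_heads argument
num_heads_kv = config.n_head_kv   # previously the num_heads_kv argument
hidden_size = config.hidden_size  # previously the hidden_size argument
bias = config.bias                # previously the bias argument
print(hidden_size // num_heads)   # head_size: 64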
@@ -221,46 +222,24 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
         self.softmax_scale = self.head_size ** (-0.5)

         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups

-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            if process_group.size() > self.num_groups:
-                raise NotImplementedError(
-                    f"Tensor Parallelism is not implemented for world_size > n groups"
-                )
-
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
+        process_group = weights.process_group
+        if self.num_groups % process_group.size() != 0:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
+            )
+
+        self.num_groups = self.num_groups // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)

     def forward(
         self,
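The rewritten body keeps the grouped multi-query layout (each group holds a block of query heads plus one key and one value head) and now takes its shard count from weights.process_group, requiring the number of groups to divide evenly across ranks. A small runnable sketch of that arithmetic; the Falcon-40B-like values (hidden_size=8192, num_heads=128, num_heads_kv=8) and world_size=4 are assumptions for illustration, not values from this diff:

# Sketch only: model sizes and world_size are assumed, not taken from the diff.
hidden_size = 8192
num_heads = 128
num_heads_kv = 8

head_size = hidden_size // num_heads          # 64
num_groups = num_heads // (num_heads_kv * 2)  # 8 groups
heads_per_group = num_heads // num_groups     # 16 query heads per group
kv_per_group = num_heads_kv // num_groups     # 1 key/value head per group

# Width of the fused query_key_value projection: each group packs its
# query heads plus 2 * kv_per_group key/value heads of size head_size.
qkv_out_features = num_groups * head_size * (heads_per_group + 2 * kv_per_group)
print(qkv_out_features)  # 9216

# The new tensor-parallel constraint: groups must shard evenly over ranks.
world_size = 4
if num_groups % world_size != 0:
    raise NotImplementedError(f"{num_groups} not divisible by {world_size}")
print(num_groups // world_size)  # 2 groups per rank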
@@ -460,8 +439,6 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output

             # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)

             return intermediate, residual
@@ -548,8 +525,6 @@ class FlashRWLargeLayer(nn.Module):

         intermediate = attn_output + mlp_output

         # Only reduce once and after the addition instead of once per layer
-        if self.process_group is not None:
-            torch.distributed.all_reduce(intermediate, group=self.process_group)

         return intermediate, residual
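The last two hunks drop the explicit per-layer all_reduce in both FlashRWLayer and FlashRWLargeLayer; where the single reduction ends up after this change is not visible in this diff. A small simulation of the "only reduce once and after the addition" equivalence that the kept comment refers to, with the tensor-parallel ranks replaced by plain tensors (no process group involved):

# Sketch only: ranks are simulated with local tensors instead of torch.distributed.
import torch

torch.manual_seed(0)
world_size = 4
# Per-rank partial outputs of the row-parallel attention dense and MLP projections.
attn_partials = [torch.randn(2, 8) for _ in range(world_size)]
mlp_partials = [torch.randn(2, 8) for _ in range(world_size)]

# Reducing each sub-layer output separately (two all-reduces) ...
two_reduces = sum(attn_partials) + sum(mlp_partials)

# ... gives the same result as adding the local partials first and reducing once.
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))

print(torch.allclose(two_reduces, one_reduce))  # True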