Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
Large attention ?

Commit: daf59b0582
Parent: d083d57d0d
@@ -209,11 +209,12 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
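The constructor now takes a config, a parameter-name prefix, and a weights handle instead of raw hyperparameters, so the module can fetch its own tensors (plus the target device and process group) from the checkpoint, as the second hunk below shows. The sketch here only illustrates that calling convention; DummyWeights, its get_tensor method, the prefix string, and the tiny tensor shapes are invented for the example and are not the text-generation-inference Weights API.

# Hypothetical sketch of the "config, prefix, weights" convention: a module
# looks up its own parameters by checkpoint name instead of being handed
# dimensions. DummyWeights / get_tensor / the prefix are illustration only.
import torch


class DummyWeights:
    def __init__(self, tensors, device, process_group=None):
        self._tensors = tensors
        self.device = device                  # mirrors weights.device in the diff
        self.process_group = process_group    # mirrors weights.process_group in the diff

    def get_tensor(self, name):
        return self._tensors[name].to(self.device)


# Tiny placeholder shapes; real RW/Falcon checkpoints are much larger.
tensors = {
    "transformer.h.0.self_attention.query_key_value.weight": torch.randn(36, 16),
    "transformer.h.0.self_attention.dense.weight": torch.randn(16, 16),
}
weights = DummyWeights(tensors, device=torch.device("cpu"))

prefix = "transformer.h.0.self_attention"
qkv_weight = weights.get_tensor(f"{prefix}.query_key_value.weight")
print(qkv_weight.shape)  # the layer would wrap such tensors in (sharded) linears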
@@ -221,46 +222,24 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
-
-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
+        process_group = weights.process_group
+        if process_group.size() > self.num_groups:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for world_size > n groups"
             )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            if process_group.size() > self.num_groups:
-                raise NotImplementedError(
-                    f"Tensor Parallelism is not implemented for world_size > n groups"
-                )
-
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
+        if self.num_groups % process_group.size() != 0:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
 
-            self.num_groups = self.num_groups // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
 
     def forward(
         self,
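The surviving checks say tensor parallelism only works when the world size does not exceed the number of head groups and divides it evenly, since the fused query_key_value projection is sharded group by group across ranks. Below is a small worked example of that bookkeeping; the concrete numbers (Falcon-40B-style: 128 query heads, 8 KV heads, hidden size 8192) are an assumption for illustration, not taken from this diff.

# Worked example of the head-group arithmetic used in __init__ above.
num_heads = 128
num_heads_kv = 8
hidden_size = 8192

head_size = hidden_size // num_heads            # 64
num_groups = num_heads // (num_heads_kv * 2)    # 8 groups
heads_per_group = num_heads // num_groups       # 16 query heads per group
kv_per_group = num_heads_kv // num_groups       # 1 K and 1 V head per group

# Fused query_key_value output width, as in the removed FastLinear call:
qkv_out = num_groups * head_size * (heads_per_group + 2 * kv_per_group)
assert qkv_out == (num_heads + 2 * num_heads_kv) * head_size  # 9216

# Tensor-parallel feasibility checks mirroring the new code:
for world_size in (1, 2, 4, 8, 16, 3):
    ok = world_size <= num_groups and num_groups % world_size == 0
    per_rank = num_groups // world_size if ok else None
    print(f"world_size={world_size}: ok={ok}, groups per rank={per_rank}")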
@@ -460,9 +439,7 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
 
-            # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate, residual
         else:
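The dropped comment states the intent: the attention dense and MLP down-projections appear to be row-parallel layers whose reduction is deferred (the reduce flag above suggests this), so each rank holds partial sums and a single all_reduce after the residual addition finishes both at once. The snippet below is a standalone sketch of why that is valid (all_reduce is a sum, and sums are linear); it is not TGI code, and the shapes and values are made up.

# Minimal two-process demo: one all_reduce after the addition equals two
# separate all_reduces. Run as a script; it spawns 2 CPU ranks (gloo backend).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend these are the *partial* outputs of row-parallel dense / MLP
    # projections on this rank, with the reduction deferred.
    attn_partial = torch.full((4,), float(rank + 1))
    mlp_partial = torch.full((4,), float(10 * (rank + 1)))

    # Option A: reduce each projection separately (two collectives).
    a = attn_partial.clone()
    b = mlp_partial.clone()
    dist.all_reduce(a)
    dist.all_reduce(b)
    two_reduces = a + b

    # Option B: add the partial sums first, then reduce once (one collective).
    intermediate = attn_partial + mlp_partial
    dist.all_reduce(intermediate)

    assert torch.equal(two_reduces, intermediate)
    if rank == 0:
        print("one all_reduce after the addition matches two all_reduces:", intermediate)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)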
@@ -548,9 +525,7 @@ class FlashRWLargeLayer(nn.Module):
 
         intermediate = attn_output + mlp_output
 
-        # Only reduce once and after the addition instead of once per layer
-        if self.process_group is not None:
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+        torch.distributed.all_reduce(intermediate, group=self.process_group)
 
         return intermediate, residual
 
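The same guard removal is applied to FlashRWLargeLayer: the all_reduce now runs unconditionally, which presumably relies on a process group always being initialized, even a trivial world-size-1 group where the collective leaves the tensor unchanged. A minimal single-process check of that assumption (standalone illustration, not TGI code):

# With a world-size-1 process group, all_reduce is a no-op on the data, so
# calling it unconditionally is harmless as long as a group always exists.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

intermediate = torch.arange(4, dtype=torch.float32)
before = intermediate.clone()
dist.all_reduce(intermediate)  # sum over a single rank: values unchanged
assert torch.equal(before, intermediate)
print(intermediate)  # tensor([0., 1., 2., 3.])

dist.destroy_process_group()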
|