From daf59b0582c048276a5b470c10cce4645e5cff3d Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 6 Jun 2023 11:08:25 +0000
Subject: [PATCH] Large attention ?

---
 .../custom_modeling/flash_rw_modeling.py      | 61 ++++++-------------
 1 file changed, 18 insertions(+), 43 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 9b175cf9..34a037ab 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -209,11 +209,12 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
@@ -221,46 +222,24 @@
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
-
-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
+        process_group = weights.process_group
+        if process_group.size() > self.num_groups:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for world_size > n groups"
            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            if process_group.size() > self.num_groups:
-                raise NotImplementedError(
-                    f"Tensor Parallelism is not implemented for world_size > n groups"
-                )
-
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
+        if self.num_groups % process_group.size() != 0:
+            raise NotImplementedError(
+                f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
            )
-            self.num_groups = self.num_groups // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
 
     def forward(
         self,
@@ -460,9 +439,7 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
 
-            # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate, residual
         else:
@@ -548,9 +525,7 @@ class FlashRWLargeLayer(nn.Module):
 
         intermediate = attn_output + mlp_output
 
-        # Only reduce once and after the addition instead of once per layer
-        if self.process_group is not None:
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+        torch.distributed.all_reduce(intermediate, group=self.process_group)
 
         return intermediate, residual
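
Note on the head geometry this patch touches (not part of the diff): FlashRWLargeAttention packs query heads into groups that share one key and one value head, and the constructor derives num_groups, the per-group head counts, and the width of the fused query_key_value projection from them. The sketch below only reproduces that arithmetic; the Falcon-40B-style dimensions (hidden_size=8192, 128 query heads, 8 KV heads) are assumptions chosen for illustration, not values taken from this patch.

# Sketch of the grouped-KV arithmetic in FlashRWLargeAttention.__init__.
# Assumed Falcon-40B-like dimensions, for illustration only.
hidden_size = 8192
num_heads = 128          # query heads
num_heads_kv = 8         # shared key/value heads
head_size = hidden_size // num_heads          # 64

num_groups = num_heads // (num_heads_kv * 2)  # 8 groups
heads_per_group = num_heads // num_groups     # 16 query heads per group
kv_per_group = num_heads_kv // num_groups     # 1 K head + 1 V head per group

# Width of the fused query_key_value projection: per group,
# heads_per_group query heads plus one key and one value head.
qkv_out = num_groups * head_size * (heads_per_group + 2 * kv_per_group)
assert qkv_out == 9216

# Tensor parallelism shards along groups, hence the two checks in the patch:
# the world size may not exceed num_groups and must divide it evenly.
for world_size in (1, 2, 4, 8):
    assert world_size <= num_groups and num_groups % world_size == 0
    print(f"{world_size} ranks -> {num_groups // world_size} groups per rank")

The remaining hunks drop the `if self.process_group is not None:` guard around the layer-level `torch.distributed.all_reduce`, reducing unconditionally after `attn_output + mlp_output` (presumably because the weights-based loading always provides a process group); the reduction still happens once per layer, after the addition.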