diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 41be97fc..6baaa5ff 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -257,6 +257,10 @@ class FlashRWLargeAttention(torch.nn.Module):
 
             self.num_groups = self.num_groups // process_group.size()
 
+        self.num_heads_config = num_heads
+        self.num_heads_kv_config = num_heads_kv
+        self.num_groups = 64
+
     def forward(
         self,
         hidden_states,
@@ -268,37 +272,56 @@ class FlashRWLargeAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
+        cu_shape = hidden_states.shape[0]
+
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        qkv = qkv.view(cu_shape, -1, self.num_heads_config // self.num_heads_kv_config +2, 64)
+        q = qkv[:, :, :-2]
+        k = qkv[:, :, [-2]]
+        v = qkv[:, :, [-1]]
 
-        # Split query from key_value
-        query, kv = qkv.split(
-            [self.num_heads, 2],
-            dim=2,
-        )
+        k = torch.broadcast_to(k, q.shape)
+        v = torch.broadcast_to(v, q.shape)
 
-        # Prepare query and key_value for indexing
-        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
-        kv = kv.transpose(1, 2)
+        q = q.reshape(cu_shape, -1, self.head_size)
+        k = k.reshape(cu_shape, -1, self.head_size)
+        v = v.reshape(cu_shape, -1, self.head_size)
+
+        logger.error(k.shape)
+
+        # qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        #
+        # # Split query from key_value
+        # query, kv = qkv.split(
+        #     [self.num_heads, 2],
+        #     dim=2,
+        # )
+        #
+        # # Prepare query and key_value for indexing
+        # query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+        # kv = kv.transpose(1, 2)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(k, cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = kv
-            k, v = kv.split(1, dim=1)
+            # layer_past[...] = kv
+            # k, v = kv.split(1, dim=1)
             # Expand to query shape
-            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # k = k.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # v = v.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            layer_past[:, 0] = k
+            layer_past[:, 1] = v
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
+                q,
                 k,
                 v,
                 attn_output,
@@ -317,19 +340,22 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
-            k, v = layer_past.split(1, dim=1)
+            # layer_past[layer_past_present_indices] = kv
+            # k, v = layer_past.split(1, dim=1)
             # Expand to query shape
-            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            layer_past[layer_past_present_indices, 0] = k
+            layer_past[layer_past_present_indices, 1] = v
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                k,
-                v,
+                q,
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -344,7 +370,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size))
+        return self.dense(attn_output.view(cu_shape, -1))
 
 
 class FlashMLP(nn.Module):
@@ -498,8 +524,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, residual = self.ln_attn(hidden_states, residual)
-        ln_mlp, _ = self.ln_mlp(hidden_states, residual)
+        ln_attn, _ = self.ln_attn(hidden_states)
+        ln_mlp, _ = self.ln_mlp(hidden_states)
 
         # Self attention.
         attn_output = self.self_attention(
@@ -522,7 +548,7 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)
 
-        return intermediate, residual
+        return intermediate + hidden_states, None
 
 
 class FlashRWPreTrainedModel(PreTrainedModel):
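For context, here is a minimal standalone sketch (not part of the patch) of the grouped q/k/v split and broadcast that the new code path performs before handing tensors to the flash attention kernel. All dimension values below are made-up illustrative numbers, not the real Falcon/RW configuration, and the variable names are mine rather than the module's.

```python
# Standalone sketch of the grouped q/k/v split + broadcast used in the patch.
# Dimensions are illustrative only (assumed, not the actual model config).
import torch

num_tokens = 5        # total tokens across the batch (cu_shape in the patch)
num_groups = 4        # number of KV groups (hypothetical)
heads_per_group = 3   # query heads per group (hypothetical)
head_size = 8         # hypothetical head dimension

# Fused projection output: per group, all query heads followed by one K and one V head.
qkv = torch.randn(num_tokens, num_groups, heads_per_group + 2, head_size)

q = qkv[:, :, :-2]    # [tokens, groups, heads_per_group, head_size]
k = qkv[:, :, [-2]]   # [tokens, groups, 1, head_size]
v = qkv[:, :, [-1]]   # [tokens, groups, 1, head_size]

# Repeat each group's single K/V head across that group's query heads.
k = torch.broadcast_to(k, q.shape)
v = torch.broadcast_to(v, q.shape)

# Flatten (groups, heads_per_group) into one head axis for the attention kernel.
q = q.reshape(num_tokens, -1, head_size)
k = k.reshape(num_tokens, -1, head_size)
v = v.reshape(num_tokens, -1, head_size)

assert q.shape == k.shape == v.shape == (num_tokens, num_groups * heads_per_group, head_size)
```

The broadcast materializes one K/V copy per query head, which is what lets this grouped (multi-query-per-group) layout feed a kernel that expects q, k and v to share the same head dimension.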