From 8c8d70999469a127fdfce5c40d587171a13e7cbd Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 30 May 2023 15:09:49 +0200
Subject: [PATCH] 40b working

---
 .../custom_modeling/flash_rw_modeling.py      | 86 ++++++-------------
 1 file changed, 26 insertions(+), 60 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 6baaa5ff..f617ec20 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,7 +1,6 @@
 import torch
 import torch.distributed
 
-from loguru import logger
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
@@ -257,10 +256,6 @@ class FlashRWLargeAttention(torch.nn.Module):
 
             self.num_groups = self.num_groups // process_group.size()
 
-        self.num_heads_config = num_heads
-        self.num_heads_kv_config = num_heads_kv
-        self.num_groups = 64
-
     def forward(
         self,
         hidden_states,
@@ -272,56 +267,32 @@ class FlashRWLargeAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        cu_shape = hidden_states.shape[0]
-
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(cu_shape, -1, self.num_heads_config // self.num_heads_kv_config +2, 64)
-        q = qkv[:, :, :-2]
-        k = qkv[:, :, [-2]]
-        v = qkv[:, :, [-1]]
-
-        k = torch.broadcast_to(k, q.shape)
-        v = torch.broadcast_to(v, q.shape)
-
-        q = q.reshape(cu_shape, -1, self.head_size)
-        k = k.reshape(cu_shape, -1, self.head_size)
-        v = v.reshape(cu_shape, -1, self.head_size)
-
-        logger.error(k.shape)
-
-        # qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
-        #
-        # # Split query from key_value
-        # query, kv = qkv.split(
-        #     [self.num_heads, 2],
-        #     dim=2,
-        # )
-        #
-        # # Prepare query and key_value for indexing
-        # query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
-        # kv = kv.transpose(1, 2)
+        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        query, kv = qkv.split(
+            [self.num_heads, 2],
+            dim=2,
+        )
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(q, cos, sin)
-        self.rotary_emb(k, cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, :, 0], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            # layer_past[...] = kv
-            # k, v = kv.split(1, dim=1)
+            layer_past[...] = kv
+            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            # k = k.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-
-            layer_past[:, 0] = k
-            layer_past[:, 1] = v
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
 
             # flash attention
             flash_attn_cuda.fwd(
-                q,
+                query,
                 k,
                 v,
                 attn_output,
@@ -340,22 +311,19 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            # layer_past[layer_past_present_indices] = kv
-            # k, v = layer_past.split(1, dim=1)
+            layer_past[layer_past_present_indices] = kv
+            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            # k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-
-            layer_past[layer_past_present_indices, 0] = k
-            layer_past[layer_past_present_indices, 1] = v
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
 
             # flash attention
             flash_attn_cuda.fwd(
-                q,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                query,
+                k,
+                v,
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -370,7 +338,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(cu_shape, -1))
+        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
 
 
 class FlashMLP(nn.Module):
@@ -591,7 +559,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_heads_kv
+            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -606,7 +574,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_groups
+            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
             )
@@ -661,9 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    2,
-                    self.kv_size,
-                    self.head_size,
+                    *self.cache_size
                 )
             )
             layer_past_present_indices = None
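
Below is a minimal shape-check sketch (not taken from the patch) of the grouped key/value layout the new forward path relies on. The sizes num_groups=8, num_heads=16, head_size=64 and the 5-token batch are hypothetical stand-ins for the RefinedWeb config values.

import torch

num_tokens, num_groups, num_heads, head_size = 5, 8, 16, 64

# Fused QKV projection output: each group carries num_heads query heads
# plus one shared key head and one shared value head.
qkv = torch.randn(num_tokens, num_groups * (num_heads + 2) * head_size)
qkv = qkv.view(-1, num_groups, num_heads + 2, head_size)

# Split query from key/value inside each group.
query, kv = qkv.split([num_heads, 2], dim=2)
query = query.reshape(-1, num_groups * num_heads, head_size)

# kv is what gets written into the cache; its trailing dims match the
# (num_groups, 2, head_size) cache_size used for the "RefinedWeb" branch.
assert kv.shape == (num_tokens, num_groups, 2, head_size)

# Broadcast each group's single K/V head across that group's query heads
# so flash attention sees matching query/key/value shapes.
k, v = kv.split(1, dim=2)
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
assert query.shape == k.shape == v.shape == (num_tokens, num_groups * num_heads, head_size)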