40b working

OlivierDehaene 2023-05-30 15:09:49 +02:00
parent bbb1d9e704
commit 8c8d709994


@@ -1,7 +1,6 @@
 import torch
 import torch.distributed
-from loguru import logger
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
@@ -257,10 +256,6 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.num_groups = self.num_groups // process_group.size()
-        self.num_heads_config = num_heads
-        self.num_heads_kv_config = num_heads_kv
-        self.num_groups = 64
     def forward(
         self,
         hidden_states,
@@ -272,56 +267,32 @@ class FlashRWLargeAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        cu_shape = hidden_states.shape[0]
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(cu_shape, -1, self.num_heads_config // self.num_heads_kv_config +2, 64)
-        q = qkv[:, :, :-2]
-        k = qkv[:, :, [-2]]
-        v = qkv[:, :, [-1]]
-        k = torch.broadcast_to(k, q.shape)
-        v = torch.broadcast_to(v, q.shape)
-        q = q.reshape(cu_shape, -1, self.head_size)
-        k = k.reshape(cu_shape, -1, self.head_size)
-        v = v.reshape(cu_shape, -1, self.head_size)
-        logger.error(k.shape)
-        # qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
-        #
-        # # Split query from key_value
-        # query, kv = qkv.split(
-        #     [self.num_heads, 2],
-        #     dim=2,
-        # )
-        #
-        # # Prepare query and key_value for indexing
-        # query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
-        # kv = kv.transpose(1, 2)
+        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        query, kv = qkv.split(
+            [self.num_heads, 2],
+            dim=2,
+        )
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
         # Inplace rotary
-        self.rotary_emb(q, cos, sin)
-        self.rotary_emb(k, cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, :, 0], cos, sin)
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            # layer_past[...] = kv
-            # k, v = kv.split(1, dim=1)
+            layer_past[...] = kv
+            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            # k = k.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            layer_past[:, 0] = k
-            layer_past[:, 1] = v
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
-                q,
+                query,
                 k,
                 v,
                 attn_output,
@@ -340,22 +311,19 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            # layer_past[layer_past_present_indices] = kv
-            # k, v = layer_past.split(1, dim=1)
+            layer_past[layer_past_present_indices] = kv
+            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            # k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            layer_past[layer_past_present_indices, 0] = k
-            layer_past[layer_past_present_indices, 1] = v
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
-                q,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                query,
+                k,
+                v,
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -370,7 +338,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
-        return self.dense(attn_output.view(cu_shape, -1))
+        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
 class FlashMLP(nn.Module):
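
Note: the reshaping that the rewritten forward path relies on can be sanity-checked in isolation. The sketch below is plain PyTorch with invented token/head counts (num_groups, num_heads and head_size stand in for the module attributes); it mirrors the view/split/expand steps above and is an illustration, not the module code.

import torch

# Invented sizes; the real values come from the model config.
tokens, num_groups, num_heads, head_size = 5, 8, 16, 64

# Stand-in for self.query_key_value(hidden_states): one fused projection.
qkv = torch.randn(tokens, num_groups * (num_heads + 2) * head_size)

# Same view/split as the new forward path.
qkv = qkv.view(-1, num_groups, num_heads + 2, head_size)
query, kv = qkv.split([num_heads, 2], dim=2)

# Flatten the query heads for flash attention.
query = query.reshape(-1, num_groups * num_heads, head_size)

# Each group carries one key slot and one value slot; expand them to the query head count.
k, v = kv.split(1, dim=2)
k = k.expand(-1, num_groups, num_heads, head_size).reshape(-1, num_groups * num_heads, head_size)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(-1, num_groups * num_heads, head_size)

print(query.shape, k.shape, v.shape)  # all torch.Size([5, 128, 64])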
@@ -591,7 +559,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_heads_kv
+            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -606,7 +574,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_groups
+            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -661,9 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    2,
-                    self.kv_size,
-                    self.head_size,
+                    *self.cache_size
                 )
             )
             layer_past_present_indices = None
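
Note: the allocation call itself sits outside the hunk; the point of *self.cache_size is that the same allocation code now works for both cache layouts. A minimal sketch, assuming a plain torch.empty and invented sizes (the real code allocates from hidden_states and the config):

import torch

num_tokens = 10
pre_allocate_past_size = None
cache_size = (8, 2, 64)  # e.g. (num_groups, 2, head_size) for the grouped variant

past_key_values = torch.empty(
    (
        num_tokens if pre_allocate_past_size is None else pre_allocate_past_size,
        *cache_size,
    )
)
print(past_key_values.shape)  # torch.Size([10, 8, 2, 64])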