commit 51b5db77f5
parent 8c8d709994
Author: OlivierDehaene
Date:   2023-05-30 15:24:21 +02:00

3 changed files with 64 additions and 25 deletions

View File

@@ -201,7 +201,10 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi:
+                if config.alibi or (
+                    config.model_type == "RefinedWebModel"
+                    and config.n_head_kv != config.n_head
+                ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
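The new guard refuses to shard "RefinedWebModel" checkpoints whose K/V head count differs from the query head count (multi-query attention, as in Falcon-7B): a single shared K/V head cannot be split across tensor-parallel ranks. A minimal, standalone sketch of the condition, using a hypothetical RWConfigStub in place of the real RWConfig:

from dataclasses import dataclass

@dataclass
class RWConfigStub:  # hypothetical stand-in for the real RWConfig
    model_type: str
    alibi: bool
    n_head: int
    n_head_kv: int

def supports_sharding(config: RWConfigStub) -> bool:
    # Mirrors the new condition: alibi models and multi-query
    # "RefinedWebModel" checkpoints cannot be sharded.
    if config.alibi:
        return False
    if config.model_type == "RefinedWebModel" and config.n_head_kv != config.n_head:
        return False
    return True

assert not supports_sharding(RWConfigStub("RefinedWebModel", False, 71, 1))  # Falcon-7B-like
assert supports_sharding(RWConfigStub("RefinedWeb", False, 128, 8))          # Falcon-40B-like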

View File

@ -107,11 +107,11 @@ class FlashRWAttention(torch.nn.Module):
) )
self.dense = FastLinear(hidden_size, hidden_size, bias=bias) self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else: else:
self.num_heads = self.num_heads // process_group.size() self.query_key_value = TensorParallelColumnLinear(
self.query_key_value = FastLinear(
hidden_size, hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv), self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias, bias=bias,
process_group=process_group,
) )
self.dense = TensorParallelRowLinear( self.dense = TensorParallelRowLinear(
hidden_size, hidden_size,
@@ -120,6 +120,7 @@ class FlashRWAttention(torch.nn.Module):
                 process_group=process_group,
                 reduce=reduce,
             )
+            self.num_heads = self.num_heads // process_group.size()

     def forward(
         self,
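Moving `self.num_heads = self.num_heads // process_group.size()` below the layer construction matters because `TensorParallelColumnLinear` takes the full, unsharded output width and shards it internally; the per-rank head count is only needed later, in `forward`. A hedged, toy illustration (this is not TGI's actual `TensorParallelColumnLinear`, just a single-rank stand-in):

import torch
import torch.nn as nn

class ToyColumnParallelLinear(nn.Module):
    """Stands in for one rank's shard of a column-parallel linear."""

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert out_features % world_size == 0, "output width must split evenly"
        # This rank materializes only its contiguous slice of the output columns.
        self.local = nn.Linear(in_features, out_features // world_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.local(x)

num_heads, num_heads_kv, head_size, hidden_size, world_size = 8, 8, 16, 128, 2
full_width = head_size * (num_heads + 2 * num_heads_kv)  # unsharded, as in the diff
qkv = ToyColumnParallelLinear(hidden_size, full_width, world_size)
local_num_heads = num_heads // world_size  # per-rank bookkeeping, done afterwards
print(qkv(torch.randn(4, hidden_size)).shape)  # torch.Size([4, 192])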
@@ -231,13 +232,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group is None:
             self.query_key_value = FastLinear(
                 hidden_size,
                 self.num_groups
                 * self.head_size
                 * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
+            if process_group.size() > self.num_groups:
+                raise NotImplementedError(
+                    f"Tensor Parallelism is not implemented for world_size > n groups"
+                )
+
             self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.num_groups
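`FlashRWLargeAttention` (Falcon-40B-style grouped attention) shards the fused QKV projection along the group dimension, so the new check rejects world sizes larger than the number of K/V groups. A small sketch of the constraint, with illustrative numbers:

num_groups = 8  # mirrors a Falcon-40B-style n_head_kv; purely illustrative
for world_size in (2, 4, 8, 16):
    allowed = world_size <= num_groups  # negation of the condition added in this hunk
    print(f"world_size={world_size}: {'ok' if allowed else 'rejected'}")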
@@ -269,10 +275,13 @@ class FlashRWLargeAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split on group dimension
         query, kv = qkv.split(
             [self.num_heads, 2],
             dim=2,
         )
+        # Merge groups and heads
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)

         # Inplace rotary
@@ -285,8 +294,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[...] = kv
             k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )

             # output
             attn_output = torch.empty_like(query)
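Wrapping the `expand(...).reshape(...)` chain only changes formatting, but the pattern itself is worth spelling out: `expand` broadcasts each group's single K/V head across that group's query heads without allocating, and the following `reshape` materializes the flat `(tokens, groups * heads, head_size)` layout the attention kernel expects. A toy demonstration:

import torch

num_groups, num_heads, head_size, tokens = 4, 2, 8, 3
kv = torch.randn(tokens, num_groups, 2, head_size)
k, v = kv.split(1, dim=2)  # one K and one V head per group: (tokens, groups, 1, head)
# expand: zero-copy broadcast of the shared head over the group's query heads;
# reshape then copies into the flat layout (expanded tensors cannot be view()ed).
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
print(k.shape)  # torch.Size([3, 8, 8])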
@@ -314,8 +327,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[layer_past_present_indices] = kv
             k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )

             # output
             attn_output = torch.empty_like(query)
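This is the decode-path twin of the previous hunk: new K/V pairs are written in place into the pre-allocated cache at the current tokens' positions, and the full history is then expanded to query shape exactly as above. A sketch of the in-place cache update (toy sizes, standalone):

import torch

cache_len, num_groups, head_size = 10, 4, 8
layer_past = torch.zeros(cache_len, num_groups, 2, head_size)  # pre-allocated K/V cache
kv = torch.randn(3, num_groups, 2, head_size)                  # K/V for 3 new tokens
layer_past_present_indices = torch.tensor([4, 7, 9])           # their slots in the cache
layer_past[layer_past_present_indices] = kv                    # in-place scatter, as in the diff
k, v = layer_past.split(1, dim=2)                              # read back the full history
print(k.shape, v.shape)  # torch.Size([10, 4, 1, 8]) twice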
@@ -338,7 +355,9 @@ class FlashRWLargeAttention(torch.nn.Module):
             None,
         )

-        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
+        return self.dense(
+            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+        )


 class FlashMLP(nn.Module):
@@ -389,7 +408,12 @@ class FlashRWLayer(nn.Module):
         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.self_attention = FlashRWAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm(hidden_size, eps=layer_norm_eps)
@ -397,7 +421,9 @@ class FlashRWLayer(nn.Module):
else None else None
) )
self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False) self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
)
self.process_group = process_group self.process_group = process_group
@@ -473,10 +499,17 @@ class FlashRWLargeLayer(nn.Module):
         self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)

         self.self_attention = FlashRWLargeAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )

         self.process_group = process_group
@@ -492,8 +525,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, _ = self.ln_attn(hidden_states)
-        ln_mlp, _ = self.ln_mlp(hidden_states)
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(residual)

         # Self attention.
         attn_output = self.self_attention(
@@ -516,13 +549,11 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)

-        return intermediate + hidden_states, None
+        return intermediate, residual


 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
-    supports_gradient_checkpointing = False
-    _no_split_modules = None


 class FlashRWModel(FlashRWPreTrainedModel):
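The substantive change in the two hunks above is the residual handling: instead of computing `intermediate + hidden_states` inside the layer and returning `None` in the residual slot, `FlashRWLargeLayer` now lets `ln_attn` fold the incoming residual into its input and returns `(intermediate, residual)`, deferring the add to the next layer's layernorm, as `FlashRWLayer` already does. A plain-PyTorch approximation of that flow (the fused residual add in TGI's `FastLayerNorm` is emulated here; its exact interface is an assumption):

import torch
import torch.nn as nn

def ln_with_residual(ln, x, residual=None):
    # Emulated contract of a fused layernorm: add the residual first,
    # return both the normalized output and the pre-norm sum for reuse.
    if residual is not None:
        x = x + residual
    return ln(x), x

hidden = 16
ln_attn, ln_mlp = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
attn, mlp = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)  # stand-ins

hidden_states, residual = torch.randn(3, hidden), None  # first layer: no residual yet
ln_attn_out, residual = ln_with_residual(ln_attn, hidden_states, residual)
ln_mlp_out, _ = ln_with_residual(ln_mlp, residual)
# Parallel block: the attention and MLP branches are summed...
intermediate = attn(ln_attn_out) + mlp(ln_mlp_out)
# ...and the residual add is deferred: the next layer receives
# (intermediate, residual) instead of this layer computing the sum itself.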
@@ -559,7 +590,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -574,7 +609,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                self.h[0].self_attention.num_groups,
+                2,
+                self.h[0].self_attention.head_size,
+            )
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -582,8 +621,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

-        self.gradient_checkpointing = False
-
         self.head_size = self.h[0].self_attention.head_size

     def post_load_weights(self, quantize: Optional[str] = None):
@@ -629,7 +666,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    *self.cache_size
+                    *self.cache_size,
                 )
             )
             layer_past_present_indices = None
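Only a trailing comma changes here, but it marks where `cache_size` is consumed: the pre-allocated past for a layer is shaped `(past_size, *self.cache_size)`. A minimal illustration with assumed values:

import torch

pre_allocate_past_size = 2048
cache_size = (8, 2, 64)  # hypothetical grouped layout from the sketch above
layer_past = torch.empty(pre_allocate_past_size, *cache_size)
print(layer_past.shape)  # torch.Size([2048, 8, 2, 64])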

View File

@@ -113,7 +113,6 @@ class FlashRW(FlashCausalLM):
         model.post_load_weights(quantize)

-

 class FlashRWSharded(FlashRW):
     def __init__(
         self,