OlivierDehaene 2023-05-30 15:24:21 +02:00
parent 8c8d709994
commit 51b5db77f5
3 changed files with 64 additions and 25 deletions


@@ -201,7 +201,10 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi:
+                if config.alibi or (
+                    config.model_type == "RefinedWebModel"
+                    and config.n_head_kv != config.n_head
+                ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
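The tightened guard above refuses to shard alibi checkpoints and "RefinedWebModel" configs that use multi-query attention (n_head_kv != n_head); every other RefinedWeb config still reaches FlashRWSharded. A minimal standalone sketch of the same check, assuming an illustrative config object rather than the real transformers config class:

from types import SimpleNamespace


def can_shard_rw(config) -> bool:
    # Alibi positional encoding is not supported by the sharded path.
    if config.alibi:
        return False
    # Multi-query "RefinedWebModel" checkpoints (one shared KV head) cannot
    # be split across tensor-parallel ranks with this commit.
    if config.model_type == "RefinedWebModel" and config.n_head_kv != config.n_head:
        return False
    return True


# Illustrative multi-query config: sharding is rejected.
cfg = SimpleNamespace(model_type="RefinedWebModel", alibi=False, n_head=64, n_head_kv=1)
assert not can_shard_rw(cfg)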


@@ -107,11 +107,11 @@ class FlashRWAttention(torch.nn.Module):
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(
+            self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.head_size * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
+                process_group=process_group,
             )
             self.dense = TensorParallelRowLinear(
                 hidden_size,
@@ -120,6 +120,7 @@ class FlashRWAttention(torch.nn.Module):
                 process_group=process_group,
                 reduce=reduce,
             )
+            self.num_heads = self.num_heads // process_group.size()
 
     def forward(
         self,
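The reordering above matters for sharding: the fused QKV projection has to be sized from the global head counts before self.num_heads is divided by the world size, otherwise each rank would build a projection that is too narrow. A rough sketch of the sizing arithmetic, with invented numbers (the real layers are TensorParallelColumnLinear / TensorParallelRowLinear):

head_size = 64
num_heads = 64
num_heads_kv = 8
world_size = 4

# The full output width is computed from the *global* head counts; the
# column-parallel linear then shards this dimension, so each rank holds
# qkv_out // world_size output features.
qkv_out = head_size * (num_heads + 2 * num_heads_kv)
assert qkv_out % world_size == 0

# Only after the layers are built is the per-rank head count derived,
# for the attention math that runs locally on each shard.
num_heads_per_rank = num_heads // world_size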
@@ -231,13 +232,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group is None:
             self.query_key_value = FastLinear(
                 hidden_size,
-                self.num_groups *
-                self.head_size
+                self.num_groups
+                * self.head_size
                 * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
+            if process_group.size() > self.num_groups:
+                raise NotImplementedError(
+                    f"Tensor Parallelism is not implemented for world_size > n groups"
+                )
+
             self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.num_groups
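The new NotImplementedError reflects how the grouped QKV of the larger RefinedWeb variant is sharded: the column split runs along the group dimension, so every rank must own at least one complete group of query heads plus its key and value head (see the qkv.view in the next hunk). A hedged sketch of the constraint, with illustrative values:

num_groups = 8   # e.g. the number of KV groups in the config
world_size = 4   # tensor-parallel ranks

# Whole groups are assigned to ranks; more ranks than groups cannot work.
if world_size > num_groups:
    raise NotImplementedError(
        "Tensor Parallelism is not implemented for world_size > n groups"
    )
groups_per_rank = num_groups // world_size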
@@ -269,10 +275,13 @@ class FlashRWLargeAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split on group dimension
         query, kv = qkv.split(
             [self.num_heads, 2],
             dim=2,
         )
+        # Merge groups and heads
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
@@ -285,8 +294,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[...] = kv
             k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -314,8 +327,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[layer_past_present_indices] = kv
             k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
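Both reformatted blocks perform the same grouped-attention step: the single K and V head of each group is broadcast across that group's query heads so the attention kernel sees matching head counts on both sides. A small self-contained illustration of the shape arithmetic, with made-up sizes:

import torch

num_tokens, num_groups, num_heads, head_size = 5, 4, 16, 64

# One K (or V) head per group, as produced by kv.split(1, dim=2)...
k = torch.randn(num_tokens, num_groups, 1, head_size)

# ...broadcast to every query head of its group, then groups and heads
# are merged into a single head dimension to match the query layout.
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
assert k.shape == (num_tokens, num_groups * num_heads, head_size)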
@@ -338,7 +355,9 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
+        return self.dense(
+            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+        )
 
 
 class FlashMLP(nn.Module):
@@ -389,7 +408,12 @@ class FlashRWLayer(nn.Module):
 
         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.self_attention = FlashRWAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm(hidden_size, eps=layer_norm_eps)
@@ -397,7 +421,9 @@ class FlashRWLayer(nn.Module):
             else None
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
@@ -473,10 +499,17 @@ class FlashRWLargeLayer(nn.Module):
         self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
 
         self.self_attention = FlashRWLargeAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
@@ -492,8 +525,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, _ = self.ln_attn(hidden_states)
-        ln_mlp, _ = self.ln_mlp(hidden_states)
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(residual)
 
         # Self attention.
         attn_output = self.self_attention(
@@ -516,13 +549,11 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)
 
-        return intermediate + hidden_states, None
+        return intermediate, residual
 
 
 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
-    supports_gradient_checkpointing = False
-    _no_split_modules = None
 
 
 class FlashRWModel(FlashRWPreTrainedModel):
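The residual changes above follow the parallel-attention layout of this layer: the attention layernorm now threads the running residual through, the MLP layernorm reads that same residual, and the layer hands the residual to the next layer instead of adding hidden_states itself. A hedged sketch of the overall pattern (ln_attn / ln_mlp / attn_fn / mlp_fn stand in for the real modules; reduce=False on the submodules is what lets the single all-reduce below cover both branches):

import torch


def parallel_layer(
    hidden_states, residual, ln_attn, ln_mlp, attn_fn, mlp_fn, process_group=None
):
    # Fold the incoming residual into the attention norm; the MLP norm
    # consumes the updated residual stream.
    ln_attn_out, residual = ln_attn(hidden_states, residual)
    ln_mlp_out, _ = ln_mlp(residual)

    # Attention and MLP are built with reduce=False, so their partial
    # outputs are summed first and all-reduced once.
    intermediate = attn_fn(ln_attn_out) + mlp_fn(ln_mlp_out)
    if process_group is not None:
        torch.distributed.all_reduce(intermediate, group=process_group)

    # Return the residual alongside the output, matching
    # `return intermediate, residual` above.
    return intermediate, residual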
@@ -559,7 +590,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
+        self.cache_size = (
+            2,
+            self.h[0].self_attention.num_heads_kv,
+            self.h[0].self_attention.head_size,
+        )
     elif config.model_type == "RefinedWeb":
         self.h = nn.ModuleList(
             [
@@ -574,7 +609,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
+        self.cache_size = (
+            self.h[0].self_attention.num_groups,
+            2,
+            self.h[0].self_attention.head_size,
+        )
     else:
         raise NotImplementedError(
             f"model_type {config.model_type} is not supported."
@@ -582,8 +621,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
         self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.gradient_checkpointing = False
-
         self.head_size = self.h[0].self_attention.head_size
 
     def post_load_weights(self, quantize: Optional[str] = None):
@@ -629,7 +666,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    *self.cache_size
+                    *self.cache_size,
                 )
             )
             layer_past_present_indices = None
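The cache_size tuples defined earlier fix the per-token slot shape of this pre-allocated cache: (2, num_heads_kv, head_size) for the multi-query "RefinedWebModel" layers and (num_groups, 2, head_size) for the grouped "RefinedWeb" layers. The trailing comma added above is a formatting-only change; the allocation itself works the same way. A toy illustration with invented sizes:

import torch

hidden_states = torch.randn(10, 4544)   # (num_tokens, hidden_size)
cache_size = (2, 1, 64)                 # e.g. (2, num_heads_kv, head_size)
pre_allocate_past_size = None

past_key_values = hidden_states.new_empty(
    (
        len(hidden_states)
        if pre_allocate_past_size is None
        else pre_allocate_past_size,
        *cache_size,
    )
)
assert past_key_values.shape == (10, 2, 1, 64)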


@@ -113,7 +113,6 @@ class FlashRW(FlashCausalLM):
         model.post_load_weights(quantize)
 
 
-
 class FlashRWSharded(FlashRW):
     def __init__(
         self,