Fixing Falcon 40b

Nicolas Patry 2023-06-07 16:17:06 +02:00
parent 5c82dcd2bf
commit cc84387877
2 changed files with 69 additions and 66 deletions

View File

@@ -18,9 +18,10 @@ from text_generation_server.utils.layers import (
     TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
-    get_linear
+    get_linear,
 )
 
+
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_sharded(f"{prefix}.weight", dim=1)
     if bias and weights.process_group.rank() == 0:
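
A side note on the load_row pattern touched by this hunk: the dense / dense_4h_to_h projections are row-parallel, so each rank multiplies its slice of the input features by its slice of the weight and the partial products are summed with an all-reduce; loading the bias on rank 0 only ensures it is added exactly once. A minimal sketch of that pattern with plain torch.distributed (illustrative only, not the text_generation_server implementation):

import torch
import torch.distributed as dist

def row_parallel_linear(x_shard, w_shard, bias=None, group=None):
    # x_shard: [batch, in_features // world_size]
    # w_shard: [out_features, in_features // world_size]
    # bias is only non-None on rank 0, mirroring load_row above
    y = torch.nn.functional.linear(x_shard, w_shard, bias)
    # sum the partial products across ranks; the single bias copy is summed in once
    dist.all_reduce(y, op=dist.ReduceOp.SUM, group=group)
    return y
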
@@ -100,7 +101,9 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
+        config,
+        prefix,
+        weights,
         reduce=True,
     ):
         super().__init__()
@ -109,12 +112,21 @@ class FlashRWAttention(torch.nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device) self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads //weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) self.query_key_value = TensorParallelColumnLinear.load(
self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward( def forward(
self, self,
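
Reflowing this constructor also makes the tensor-parallel split easier to read: query_key_value is column-parallel, and self.num_heads is divided by the world size so each rank only attends with its own slice of heads. A rough sketch of the idea, assuming a plain [q-heads | k-heads | v-heads] fused layout (Falcon's real fused layout differs, so this is illustrative only):

import torch

def shard_qkv_heads(qkv_weight, num_heads, head_size, rank, world_size):
    # qkv_weight: [3 * num_heads * head_size, hidden_size]
    assert num_heads % world_size == 0, "heads must divide evenly across ranks"
    per_rank = num_heads // world_size
    w = qkv_weight.view(3, num_heads, head_size, -1)
    shard = w[:, rank * per_rank : (rank + 1) * per_rank]  # keep whole heads on one rank
    return shard.reshape(3 * per_rank * head_size, -1)
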
@@ -204,26 +216,29 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()
 
+        hidden_size = config.hidden_size
+        num_heads = config.n_head
+        num_heads_kv = config.n_head_kv
+
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
-        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
+
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
+
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
 
         process_group = weights.process_group
         if process_group.size() > self.num_groups:
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for world_size > n groups"
@@ -232,9 +247,17 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+        self.num_groups = self.num_groups // process_group.size()
 
-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )
 
     def forward(
         self,
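
The newly added self.num_groups = self.num_groups // process_group.size() looks like the substantive change in this hunk: the grouped heads are now sharded across tensor-parallel ranks, subject to the two NotImplementedError guards above. Continuing the numbers from the previous sketch:

num_groups = 8                       # per the worked example above
for world_size in (1, 2, 4, 8):
    assert world_size <= num_groups and num_groups % world_size == 0
    print(world_size, "ranks ->", num_groups // world_size, "groups per rank")
# world_size values such as 3 or 16 would hit the NotImplementedError branches instead
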
@@ -331,12 +354,16 @@ class FlashRWLargeAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights, reduce=True):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu
 
-        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
-        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+        )
+        self.dense_4h_to_h = load_row(
+            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+        )
 
     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
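
FlashMLP itself is the standard Falcon feed-forward block; ignoring the tensor-parallel sharding of the two projections, its forward pass amounts to the following sketch:

import torch

def falcon_mlp(hidden_states, w_h_to_4h, b_h_to_4h, w_4h_to_h, b_4h_to_h):
    # hidden -> 4 * hidden -> GELU -> hidden, matching dense_h_to_4h / dense_4h_to_h above
    x = torch.nn.functional.linear(hidden_states, w_h_to_4h, b_h_to_4h)
    x = torch.nn.functional.gelu(x)
    return torch.nn.functional.linear(x, w_4h_to_h, b_4h_to_h)
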
@@ -351,20 +378,9 @@ class FlashRWLayer(nn.Module):
         layer_id,
         config,
         weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # layer_norm_eps,
-        # parallel_attn,
-        # process_group=None,
     ):
         super().__init__()
 
-        n_head = config.n_head
-        n_head_kv = config.n_head_kv
-        hidden_size = config.hidden_size
-        bias = config.bias
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn
 
@ -376,11 +392,6 @@ class FlashRWLayer(nn.Module):
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
self.self_attention = FlashRWAttention( self.self_attention = FlashRWAttention(
# num_heads,
# num_heads_kv,
# hidden_size,
# bias,
# process_group=process_group,
config, config,
prefix=f"{prefix}.self_attention", prefix=f"{prefix}.self_attention",
weights=weights, weights=weights,
@@ -391,16 +402,16 @@ class FlashRWLayer(nn.Module):
                 prefix=f"{prefix}.post_attention_layernorm",
                 weights=weights,
                 eps=config.layer_norm_epsilon,
-            ) if not parallel_attn
+            )
+            if not parallel_attn
             else None
         )
 
         self.mlp = FlashMLP(
-            # hidden_size, bias, process_group=process_group, reduce=False
             config,
             prefix=f"{prefix}.mlp",
             weights=weights,
-            reduce=False
+            reduce=False,
         )
 
         self.process_group = weights.process_group
@@ -461,11 +472,9 @@ class FlashRWLayer(nn.Module):
 
 class FlashRWLargeLayer(nn.Module):
-    def __init__(
-        self,
-        config, prefix, weights
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
+        prefix = f"transformer.h.{layer_id}"
 
         self.ln_attn = FastLayerNorm.load(
             prefix=f"{prefix}.ln_attn",
             weights=weights,
@@ -478,13 +487,13 @@ class FlashRWLargeLayer(nn.Module):
         )
 
         self.self_attention = FlashRWLargeAttention(
-            config, prefix=f"{prefix}.self_attention", weights=weights,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
+        assert config.parallel_attn, "This version doesn't support non parallel_attn"
 
-        self.mlp = FlashMLP(
-            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
-        )
+        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
 
         self.process_group = weights.process_group
 
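
Deriving prefix from layer_id inside FlashRWLargeLayer means the checkpoint tensor names are fully determined by the layer index. For layer 0, the modules above resolve to names along the lines of:

prefix = "transformer.h.0"
# ln_attn        -> "transformer.h.0.ln_attn.weight"
# self_attention -> "transformer.h.0.self_attention.query_key_value.weight"
# mlp            -> "transformer.h.0.mlp.dense_h_to_4h.weight"

(The exact ".weight" / ".bias" suffixes come from the individual loaders, e.g. load_row's weights.get_sharded(f"{prefix}.weight") in the first hunk.)
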
@@ -541,7 +550,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.h = nn.ModuleList(
             [
                 FlashRWLayer(
-                    layer_id, config, weights
+                    layer_id,
+                    config,
+                    weights
                     # config.n_head,
                     # config.n_head_kv,
                     # config.hidden_size,
@@ -561,15 +572,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(
-                        layer_id, config, weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # process_group,
-                    )
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )

View File

@@ -310,11 +310,12 @@ try:
 
         @staticmethod
         def static(dim, base, device):
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
-                                                    dtype=torch.float32) / dim))
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
             return PositionRotaryEmbedding(inv_freq)
 
         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
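
The second file only reflows PositionRotaryEmbedding.static; the inverse frequencies it computes are unchanged and follow the usual rotary-embedding schedule inv_freq[i] = base ** (-2i / dim). A quick numeric check with the head_size used above:

import torch

dim, base = 64, 10000.0  # head_size = 64, as in the Falcon configs discussed earlier
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
# inv_freq[0] = 1.0, inv_freq[1] = 10000 ** (-2 / 64) ≈ 0.750, ..., inv_freq[31] ≈ 1.33e-4
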
@@ -324,7 +325,6 @@ try:
             weights.dtype = dtype
             return PositionRotaryEmbedding(inv_freq)
 
-
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)