Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing Falcon 40b

commit cc84387877
parent 5c82dcd2bf
The first group of hunks is in the flash RW (Falcon) modeling code:

@@ -18,9 +18,10 @@ from text_generation_server.utils.layers import (
     TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
-    get_linear
+    get_linear,
 )
 
+
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_sharded(f"{prefix}.weight", dim=1)
     if bias and weights.process_group.rank() == 0:
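
For context on the load_row helper above: it shards the row-parallel weight along dim=1 and loads the bias only on rank 0, so the bias is added exactly once after the partial results are reduced. A minimal plain-torch sketch of that layout (hypothetical sizes, not TGI's Weights/TensorParallel API):

# A minimal sketch (not TGI's API) of why a row-parallel weight is sharded
# along dim=1 and the bias kept on rank 0 only: each rank computes a partial
# matmul over its slice of the input features, the partial results are summed
# (an all-reduce in the real code), and the bias must be added exactly once.
import torch

torch.manual_seed(0)
world_size = 2                      # hypothetical number of shards
x = torch.randn(4, 8)               # (batch, in_features)
weight = torch.randn(6, 8)          # (out_features, in_features)
bias = torch.randn(6)

# Reference: the unsharded linear layer.
full = torch.nn.functional.linear(x, weight, bias)

# Row-parallel: split in_features (weight dim=1 and the matching x columns).
w_shards = weight.chunk(world_size, dim=1)
x_shards = x.chunk(world_size, dim=1)
partial = [torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
# "all-reduce": sum the partial outputs, then add the bias exactly once
# (this is why only process_group.rank() == 0 loads the bias).
sharded = sum(partial) + bias

assert torch.allclose(full, sharded, atol=1e-5)
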
@@ -100,7 +101,9 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
+        config,
+        prefix,
+        weights,
         reduce=True,
     ):
         super().__init__()
@@ -109,12 +112,21 @@ class FlashRWAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )
 
     def forward(
         self,
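
A small worked example of the bookkeeping in FlashRWAttention.__init__ above: head_size, the attention scale, and the per-rank head count. The concrete numbers (hidden_size=8192, n_head=128, world_size=8) are assumptions for illustration, not values read from the diff:

# head_size comes from hidden_size // num_heads, the attention scale is
# head_size ** -0.5, and the heads are split evenly across the
# tensor-parallel ranks (hypothetical Falcon-40B-like numbers).
hidden_size = 8192        # assumed model width
num_heads = 128           # assumed head count
world_size = 8            # hypothetical tensor-parallel degree

head_size = hidden_size // num_heads          # 64
softmax_scale = head_size ** (-0.5)           # 0.125
heads_per_rank = num_heads // world_size      # 16 heads handled per GPU

print(head_size, softmax_scale, heads_per_rank)
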
@@ -204,26 +216,29 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()
+
+        hidden_size = config.hidden_size
+        num_heads = config.n_head
+        num_heads_kv = config.n_head_kv
+
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
         process_group = weights.process_group
 
         if process_group.size() > self.num_groups:
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for world_size > n groups"
@@ -232,9 +247,17 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+        self.num_groups = self.num_groups // process_group.size()
 
-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )
 
     def forward(
         self,
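
The two hunks above are the heart of the Falcon 40B fix: FlashRWLargeAttention now derives hidden_size, n_head and n_head_kv from the config instead of referencing undefined names, and the new `self.num_groups = self.num_groups // process_group.size()` line splits the key/value groups across tensor-parallel ranks. The sketch below mirrors the constructor's arithmetic with Falcon-40B-style numbers (n_head=128, n_head_kv=8 are assumptions) and a hypothetical world_size of 4:

# Worked numbers for the grouping logic above (assumed config values).
n_head = 128
n_head_kv = 8
world_size = 4                                    # hypothetical TP degree

num_groups = n_head // (n_head_kv * 2)            # 8
heads_per_group = n_head // num_groups            # 16 query heads per group
kv_heads_per_group = n_head_kv // num_groups      # 1 kv head per group

assert world_size <= num_groups                   # else NotImplementedError
assert num_groups % world_size == 0               # else NotImplementedError
groups_per_rank = num_groups // world_size        # 2 -- the line this commit adds
print(num_groups, heads_per_group, kv_heads_per_group, groups_per_rank)
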
@@ -331,12 +354,16 @@ class FlashRWLargeAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights, reduce=True):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu
 
-        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
-        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+        )
+        self.dense_4h_to_h = load_row(
+            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+        )
 
     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
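
FlashMLP keeps the usual tensor-parallel split: dense_h_to_4h is column-parallel, dense_4h_to_h is row-parallel (via load_row). A minimal plain-torch check of why that pairing needs only one reduction at the end (hypothetical sizes, not TGI's classes):

# dense_h_to_4h is split by output columns and dense_4h_to_h by input rows,
# so each rank applies gelu to its own slice and the partial outputs only
# need a single sum (all-reduce) at the end.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2                         # hypothetical number of shards
h, inner = 8, 32                       # hidden size and 4*h
x = torch.randn(3, h)
w1 = torch.randn(inner, h)             # dense_h_to_4h.weight
w2 = torch.randn(h, inner)             # dense_4h_to_h.weight

full = F.linear(F.gelu(F.linear(x, w1)), w2)

w1_shards = w1.chunk(world_size, dim=0)          # column-parallel: split outputs
w2_shards = w2.chunk(world_size, dim=1)          # row-parallel: split inputs
partial = [
    F.linear(F.gelu(F.linear(x, a)), b) for a, b in zip(w1_shards, w2_shards)
]
assert torch.allclose(full, sum(partial), rtol=1e-4, atol=1e-5)
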
@@ -351,20 +378,9 @@ class FlashRWLayer(nn.Module):
         layer_id,
         config,
         weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # layer_norm_eps,
-        # parallel_attn,
-        # process_group=None,
     ):
         super().__init__()
 
-        n_head = config.n_head
-        n_head_kv = config.n_head_kv
-        hidden_size = config.hidden_size
-        bias = config.bias
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn
 
@@ -376,11 +392,6 @@ class FlashRWLayer(nn.Module):
             eps=config.layer_norm_epsilon,
         )
         self.self_attention = FlashRWAttention(
-            # num_heads,
-            # num_heads_kv,
-            # hidden_size,
-            # bias,
-            # process_group=process_group,
             config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
@@ -391,16 +402,16 @@ class FlashRWLayer(nn.Module):
                 prefix=f"{prefix}.post_attention_layernorm",
                 weights=weights,
                 eps=config.layer_norm_epsilon,
-            ) if not parallel_attn
+            )
+            if not parallel_attn
             else None
         )
 
         self.mlp = FlashMLP(
-            # hidden_size, bias, process_group=process_group, reduce=False
             config,
             prefix=f"{prefix}.mlp",
             weights=weights,
-            reduce=False
+            reduce=False,
         )
 
         self.process_group = weights.process_group
@@ -461,11 +472,9 @@ class FlashRWLayer(nn.Module):
 
 
 class FlashRWLargeLayer(nn.Module):
-    def __init__(
-        self,
-        config, prefix, weights
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
+        prefix = f"transformer.h.{layer_id}"
         self.ln_attn = FastLayerNorm.load(
             prefix=f"{prefix}.ln_attn",
             weights=weights,
@@ -478,13 +487,13 @@ class FlashRWLargeLayer(nn.Module):
         )
 
         self.self_attention = FlashRWLargeAttention(
-            config, prefix=f"{prefix}.self_attention", weights=weights,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
+        assert config.parallel_attn, "This version doesn't support non parallel_attn"
 
-        self.mlp = FlashMLP(
-            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
-        )
+        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
 
         self.process_group = weights.process_group
 
@@ -541,7 +550,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(
-                        layer_id, config, weights
+                        layer_id,
+                        config,
+                        weights
                         # config.n_head,
                         # config.n_head_kv,
                         # config.hidden_size,
@@ -561,15 +572,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(
-                        layer_id, config, weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # process_group,
-                    )
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
The remaining hunks are in PositionRotaryEmbedding, in text_generation_server.utils.layers:

@@ -310,11 +310,12 @@ try:
 
         @staticmethod
         def static(dim, base, device):
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
-                                                    dtype=torch.float32) / dim))
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
             return PositionRotaryEmbedding(inv_freq)
 
 
         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
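
The reformatted expression in PositionRotaryEmbedding.static computes the standard rotary inverse frequencies, inv_freq[i] = 1 / base**(2*i/dim). A quick check with head_size-like values (dim=64 and base=10000.0 are assumptions for illustration):

# inv_freq has dim/2 entries; the lowest index rotates fastest.
import torch

dim, base = 64, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
assert inv_freq.shape == (dim // 2,)
assert inv_freq[0] == 1.0
assert torch.isclose(inv_freq[-1], torch.tensor(1.0 / base ** (62 / 64)))
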
@@ -324,7 +325,6 @@ try:
             weights.dtype = dtype
             return PositionRotaryEmbedding(inv_freq)
 
-
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)