mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
Fixing Falcon 40b
This commit is contained in:
parent 5c82dcd2bf
commit cc84387877
@@ -18,9 +18,10 @@ from text_generation_server.utils.layers import (
     TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
-    get_linear
+    get_linear,
 )


 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_sharded(f"{prefix}.weight", dim=1)
     if bias and weights.process_group.rank() == 0:
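For context, load_row shards the weight along its input dimension (dim=1) and only loads the bias on rank 0. A minimal sketch of the row-parallel pattern this feeds, written against plain torch.distributed; the function name and arguments below are illustrative assumptions, not the get_linear implementation used in this file.

import torch
import torch.distributed as dist

def row_parallel_linear(x_shard, weight_shard, bias, process_group):
    # Each rank holds a slice of the weight along the input dimension
    # (dim=1, matching load_row above), so this is only a partial result.
    # bias is non-None on rank 0 only, so it enters the sum exactly once.
    out = torch.nn.functional.linear(x_shard, weight_shard, bias)
    # Summing the partial results across ranks recovers the full output.
    dist.all_reduce(out, group=process_group)
    return out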
@@ -100,7 +101,9 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
+        config,
+        prefix,
+        weights,
         reduce=True,
     ):
         super().__init__()
@@ -109,12 +112,21 @@ class FlashRWAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
-        self.num_heads = self.num_heads //weights.process_group.size()
+        self.num_heads = self.num_heads // weights.process_group.size()

-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )

     def forward(
         self,
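The attention init above keeps the usual pairing: query_key_value is column-parallel (each rank keeps num_heads // process_group.size() heads) while dense goes through the row-parallel load_row. A quick per-rank shape check; hidden_size, num_heads and world_size below are illustrative assumptions, not values taken from this diff.

# Per-rank bookkeeping for the sharded attention above (illustrative numbers).
hidden_size = 4096                            # assumed model width
num_heads = 32                                # assumed total number of heads
world_size = 4                                # tensor-parallel group size
head_size = hidden_size // num_heads          # 128
heads_per_rank = num_heads // world_size      # 8, mirroring num_heads // weights.process_group.size()
# Column-parallel query_key_value: each rank projects q, k, v for its heads only.
qkv_out_per_rank = 3 * heads_per_rank * head_size      # 3072
# Row-parallel dense: each rank consumes its heads' slice; the partial outputs
# are then summed across ranks.
dense_in_per_rank = heads_per_rank * head_size         # 1024
print(head_size, heads_per_rank, qkv_out_per_rank, dense_in_per_rank)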
@@ -204,26 +216,29 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()

         hidden_size = config.hidden_size
         num_heads = config.n_head
         num_heads_kv = config.n_head_kv

         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)

         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
         process_group = weights.process_group

         if process_group.size() > self.num_groups:
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for world_size > n groups"
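The grouping above is what constrains tensor parallelism for the multi-query layout: whole groups are sharded, so the world size has to divide num_groups. A worked example assuming Falcon-40B-style values (n_head=128, n_head_kv=8, world_size=4 are illustrative assumptions, not read from this diff).

n_head = 128                                 # assumed Falcon-40B value
n_head_kv = 8                                # assumed Falcon-40B value
world_size = 4                               # hypothetical tensor-parallel size

num_groups = n_head // (n_head_kv * 2)       # 8, mirroring self.num_groups above
num_heads = n_head // num_groups             # 16 query heads per group
num_heads_kv = n_head_kv // num_groups       # 1 KV head per group

# Sharding happens over whole groups, so world_size must divide num_groups.
assert world_size <= num_groups and num_groups % world_size == 0
groups_per_rank = num_groups // world_size   # 2
print(num_groups, num_heads, num_heads_kv, groups_per_rank)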
@@ -232,9 +247,17 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
         self.num_groups = self.num_groups // process_group.size()

-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )

     def forward(
         self,
@@ -331,12 +354,16 @@ class FlashRWLargeAttention(torch.nn.Module):


 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights, reduce=True):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu

-        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
-        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+        )
+        self.dense_4h_to_h = load_row(
+            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+        )

     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
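The MLP uses the same split: dense_h_to_4h is column-parallel and dense_4h_to_h is row-parallel, so only the 4h expansion is sharded. Per-rank sizes with assumed numbers (hidden_size=4096, world_size=4, and the 4x expansion are illustrative, not taken from the diff).

hidden_size = 4096                              # assumed model width
world_size = 4                                  # tensor-parallel group size
ffn_size = 4 * hidden_size                      # 16384, the assumed "4h" expansion
# Column-parallel dense_h_to_4h: each rank produces a slice of the 4h activations.
h_to_4h_out_per_rank = ffn_size // world_size   # 4096
# Row-parallel dense_4h_to_h: each rank consumes its slice and the partial
# hidden_size outputs are summed across ranks.
dense_4h_to_h_in_per_rank = ffn_size // world_size   # 4096
print(h_to_4h_out_per_rank, dense_4h_to_h_in_per_rank)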
@@ -351,20 +378,9 @@ class FlashRWLayer(nn.Module):
         layer_id,
         config,
         weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # layer_norm_eps,
-        # parallel_attn,
-        # process_group=None,
     ):
         super().__init__()

-        n_head = config.n_head
-        n_head_kv = config.n_head_kv
-        hidden_size = config.hidden_size
-        bias = config.bias
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn

@@ -376,11 +392,6 @@ class FlashRWLayer(nn.Module):
             eps=config.layer_norm_epsilon,
         )
         self.self_attention = FlashRWAttention(
-            # num_heads,
-            # num_heads_kv,
-            # hidden_size,
-            # bias,
-            # process_group=process_group,
             config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
@@ -391,16 +402,16 @@ class FlashRWLayer(nn.Module):
                 prefix=f"{prefix}.post_attention_layernorm",
                 weights=weights,
                 eps=config.layer_norm_epsilon,
-            ) if not parallel_attn
+            )
+            if not parallel_attn
             else None
         )

         self.mlp = FlashMLP(
-            # hidden_size, bias, process_group=process_group, reduce=False
            config,
            prefix=f"{prefix}.mlp",
            weights=weights,
-            reduce=False
+            reduce=False,
         )

         self.process_group = weights.process_group
@@ -461,11 +472,9 @@ class FlashRWLayer(nn.Module):


 class FlashRWLargeLayer(nn.Module):
-    def __init__(
-        self,
-        config, prefix, weights
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
+        prefix = f"transformer.h.{layer_id}"
         self.ln_attn = FastLayerNorm.load(
             prefix=f"{prefix}.ln_attn",
             weights=weights,
@@ -478,13 +487,13 @@ class FlashRWLargeLayer(nn.Module):
         )

         self.self_attention = FlashRWLargeAttention(
-            config, prefix=f"{prefix}.self_attention", weights=weights,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
         assert config.parallel_attn, "This version doesn't support non parallel_attn"

-        self.mlp = FlashMLP(
-            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
-        )
+        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)

         self.process_group = weights.process_group

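FlashRWLargeLayer asserts config.parallel_attn: attention and MLP read the same block input through their own layernorms and their outputs are summed, so a single all-reduce per layer suffices, which is why the sub-modules above are built without their own reduction. A minimal sketch of that flow under those assumptions; the callables and residual handling are hypothetical stand-ins, not this file's forward pass.

import torch.distributed as dist

def parallel_attn_block(hidden_states, ln_attn, ln_mlp, attention, mlp, process_group):
    # Attention and MLP branches both start from hidden_states, each behind
    # its own layernorm, and their partial outputs are summed.
    attn_out = attention(ln_attn(hidden_states))
    mlp_out = mlp(ln_mlp(hidden_states))
    out = attn_out + mlp_out
    # One all-reduce for the whole layer instead of one per sub-module.
    if process_group is not None and process_group.size() > 1:
        dist.all_reduce(out, group=process_group)
    return hidden_states + out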
@@ -541,7 +550,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(
-                        layer_id, config, weights
+                        layer_id,
+                        config,
+                        weights
                         # config.n_head,
                         # config.n_head_kv,
                         # config.hidden_size,
@@ -561,15 +572,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(
-                        layer_id, config, weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # process_group,
-                    )
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )

@@ -310,11 +310,12 @@ try:

         @staticmethod
         def static(dim, base, device):
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
-                                                    dtype=torch.float32) / dim))
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
             return PositionRotaryEmbedding(inv_freq)


         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
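The reformatted expression is the standard rotary inverse-frequency table, inv_freq[i] = 1 / base ** (2 * i / dim), one entry per pair of rotated channels. A small standalone check with example values (dim=64 and base=10000.0 are illustrative; the model passes head_size and 10000.0).

import torch

dim, base = 64, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
# dim // 2 frequencies, decaying geometrically from 1.0 to base ** (-(dim - 2) / dim).
assert inv_freq.shape == (dim // 2,)
assert inv_freq[0] == 1.0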
@@ -324,7 +325,6 @@ try:
             weights.dtype = dtype
             return PositionRotaryEmbedding(inv_freq)

-
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)