Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
add: support for falcon-11B architecture.
This commit is contained in:
parent d3d83e7d04
commit 46ada47963
@@ -481,6 +481,44 @@ class FlashRWLayer(nn.Module):
        return mlp_output, residual


class FlashRWLayerNorm(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        # Checkpoints may define either a single input layer norm or separate
        # attention/MLP layer norms; `num_ln_in_parallel_attn` selects the layout.
        self.num_ln = config.num_ln_in_parallel_attn

        if self.num_ln == 1:
            self.input_ln = FastLayerNorm.load(
                prefix=f"{prefix}.input_layernorm",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
        elif self.num_ln == 2:
            self.ln_attn = FastLayerNorm.load(
                prefix=f"{prefix}.ln_attn",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
            self.ln_mlp = FastLayerNorm.load(
                prefix=f"{prefix}.ln_mlp",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(
        self,
        hidden_states,
        residual,
    ):
        # Returns (attention input, MLP input, updated residual) so the parallel
        # block can feed both branches from one call.
        if self.num_ln == 1:
            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
            return ln_hidden_states, ln_hidden_states, residual
        elif self.num_ln == 2:
            ln_attn, residual = self.ln_attn(hidden_states, residual)
            ln_mlp, _ = self.ln_mlp(residual)
            return ln_attn, ln_mlp, residual
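For orientation, here is a minimal plain-PyTorch sketch of the dispatch above. It assumes the FastLayerNorm contract used in this file (the incoming residual is folded in before normalization and the pre-norm sum is returned as the new residual); DemoParallelLayerNorm, hidden_size, and the use of nn.LayerNorm are illustrative stand-ins, not part of the repository.

import torch
from torch import nn


class DemoParallelLayerNorm(nn.Module):
    # Illustrative stand-in for FlashRWLayerNorm built on plain nn.LayerNorm.
    def __init__(self, hidden_size, num_ln, eps=1e-5):
        super().__init__()
        self.num_ln = num_ln
        if num_ln == 1:
            self.input_ln = nn.LayerNorm(hidden_size, eps=eps)
        elif num_ln == 2:
            self.ln_attn = nn.LayerNorm(hidden_size, eps=eps)
            self.ln_mlp = nn.LayerNorm(hidden_size, eps=eps)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states, residual=None):
        # Assumed FastLayerNorm behaviour: add the residual, normalize, and
        # return the pre-norm sum as the new residual.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        if self.num_ln == 1:
            normed = self.input_ln(hidden_states)
            return normed, normed, residual
        return self.ln_attn(hidden_states), self.ln_mlp(hidden_states), residual


# Example: two norms, both applied to the same residual sum.
attn_in, mlp_in, res = DemoParallelLayerNorm(64, num_ln=2)(torch.randn(4, 64))

With num_ln == 1 the attention and MLP branches share one normalized input; with num_ln == 2 each branch gets its own norm applied to the same residual sum.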
@@ -564,7 +602,7 @@ class FlashRWLargeLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"transformer.h.{layer_id}"

        # A single FlashRWLayerNorm now covers both the one- and two-norm layouts.
        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)

        self.self_attention = FlashRWLargeAttention(
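For context, a hedged sketch of how a parallel-attention block can consume the three-way return value of the layer norm wrapper; parallel_block_forward, attention_fn, and mlp_fn are illustrative placeholders, not functions from this repository.

def parallel_block_forward(ln_layer, attention_fn, mlp_fn, hidden_states, residual):
    # Both branches read their own normalized view of the same residual sum;
    # their outputs are combined and the residual is threaded onward.
    ln_attn, ln_mlp, residual = ln_layer(hidden_states, residual)
    attn_output = attention_fn(ln_attn)
    mlp_output = mlp_fn(ln_mlp)
    return attn_output + mlp_output, residual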