add: support for falcon-10B architecture.

Nilabhra 2024-04-15 13:52:20 +04:00
parent 80ba799c88
commit 22c005fac3


@@ -7,8 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
-from text_generation_server.layers import (
+from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -139,10 +138,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=self.rope_theta,
-            device=weights.device,
+            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
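For reference, the rotary embedding constructed here (only reformatted in this hunk) follows the standard RoPE recipe: per-pair inverse frequencies base^(-2i/dim), applied as a position-dependent rotation of each channel pair, while softmax_scale = head_size ** -0.5 is the usual attention temperature. A minimal sketch, independent of TGI's PositionRotaryEmbedding and using the interleaved pairing convention (TGI's kernels may pair channels differently):

import torch


def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (seq_len, num_heads, head_size); rotate (even, odd) channel
    # pairs by an angle proportional to the token position.
    seq_len, _, dim = x.shape
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq
    cos = angles.cos()[:, None, :]  # (seq_len, 1, dim / 2)
    sin = angles.sin()[:, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


q = rope_rotate(torch.randn(16, 4, 64))  # e.g. queries before attention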
@@ -480,6 +476,44 @@ class FlashRWLayer(nn.Module):
             return mlp_output, residual
 
 
+class FlashRWLayerNorm(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.num_ln = config.num_ln_in_parallel_attn
+
+        if self.num_ln == 1:
+            self.input_ln = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        elif self.num_ln == 2:
+            self.ln_attn = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_attn",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+            self.ln_mlp = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_mlp",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            raise ValueError("Number of layer norms can either be 1 or 2.")
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+    ):
+        if self.num_ln == 1:
+            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+            return ln_hidden_states, ln_hidden_states, residual
+        elif self.num_ln == 2:
+            ln_attn, residual = self.ln_attn(hidden_states, residual)
+            ln_mlp, _ = self.ln_mlp(residual)
+            return ln_attn, ln_mlp, residual
+
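The FlashRWLayerNorm wrapper added above lets one decoder-layer implementation serve both Falcon checkpoint layouts: a single shared input_layernorm (num_ln_in_parallel_attn == 1) or separate ln_attn/ln_mlp norms for the parallel attention and MLP branches (== 2, as in e.g. Falcon-40B). A standalone sketch of that dispatch using plain torch.nn.LayerNorm in place of TGI's FastLayerNorm, whose fused residual handling is omitted here; ParallelLayerNorm and its arguments are illustrative names only:

import torch
from torch import nn


class ParallelLayerNorm(nn.Module):
    # One shared norm, or one norm per parallel branch.
    def __init__(self, hidden_size: int, num_ln: int, eps: float = 1e-5):
        super().__init__()
        self.num_ln = num_ln
        if num_ln == 1:
            self.input_ln = nn.LayerNorm(hidden_size, eps=eps)
        elif num_ln == 2:
            self.ln_attn = nn.LayerNorm(hidden_size, eps=eps)
            self.ln_mlp = nn.LayerNorm(hidden_size, eps=eps)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states: torch.Tensor):
        # Always return (attn_input, mlp_input) so the caller never
        # needs to know which layout the checkpoint uses.
        if self.num_ln == 1:
            ln = self.input_ln(hidden_states)
            return ln, ln
        return self.ln_attn(hidden_states), self.ln_mlp(hidden_states)


attn_in, mlp_in = ParallelLayerNorm(64, num_ln=2)(torch.randn(2, 8, 64))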
@@ -524,7 +558,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
+        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
         self.self_attention = FlashRWLargeAttention(
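The hunk above swaps FlashRWLargeLayer's inline norm setup for the new wrapper; the layer's forward (not shown in this diff) can then feed the two returned tensors to the attention and MLP branches, which run in parallel and are summed. A self-contained sketch of that block shape, where ParallelBlockSketch, the head count, and nn.MultiheadAttention are illustrative stand-ins rather than TGI's implementation:

import torch
from torch import nn


class ParallelBlockSketch(nn.Module):
    # Attention and MLP read the same input in parallel; their outputs
    # are summed with the residual, as in Falcon-style decoder layers.
    def __init__(self, hidden_size: int, num_ln: int):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size)
        # Sharing the module mirrors the single-norm layout.
        self.ln_mlp = nn.LayerNorm(hidden_size) if num_ln == 2 else self.ln_attn
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, m = self.ln_attn(x), self.ln_mlp(x)
        attn_out, _ = self.attn(a, a, a, need_weights=False)
        return x + attn_out + self.mlp(m)


y = ParallelBlockSketch(64, num_ln=2)(torch.randn(2, 8, 64))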