add: support for falcon-10B architecture.

Nilabhra 2024-04-15 13:52:20 +04:00
parent 762dbf3f19
commit f2b3d8d7ed


@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -48,6 +47,7 @@ class RWConfig(PretrainedConfig):
         hidden_size=64,
         num_hidden_layers=None,
         num_attention_heads=None,
+        num_ln_in_parallel_attn=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -61,6 +61,7 @@ class RWConfig(PretrainedConfig):
         new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
+        rope_theta=10_000.0,
         **kwargs,
     ):
         if alibi:
@@ -71,6 +72,7 @@ class RWConfig(PretrainedConfig):
         self.model_type = model_type
         self.alibi = False
         self.rotary = True
+        self.rope_theta = rope_theta
 
         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
@@ -87,6 +89,7 @@ class RWConfig(PretrainedConfig):
             else kwargs.pop("n_head", 8)
         )
         self.layer_norm_epsilon = layer_norm_epsilon
+        self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
         self.initializer_range = initializer_range
         self.use_cache = use_cache
         self.hidden_dropout = hidden_dropout
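
Not part of the diff, just orientation: with these two fields in place, a Falcon-style RWConfig can spell out the rotary base and the number of parallel-attention layer norms directly. A minimal sketch, assuming the file being edited is text_generation_server/models/custom_modeling/flash_rw_modeling.py and using placeholder values rather than a real checkpoint's config.json:

from text_generation_server.models.custom_modeling.flash_rw_modeling import RWConfig

# Placeholder values, not taken from any real Falcon checkpoint.
config = RWConfig(
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_ln_in_parallel_attn=1,  # single shared layer norm in the parallel block
    rope_theta=500_000.0,       # rotary base consumed by the attention modules
    parallel_attn=True,
)

assert config.rope_theta == 500_000.0
assert config.num_ln_in_parallel_attn == 1

Checkpoints that omit these keys keep the defaults shown in the signature above (rope_theta=10_000.0, num_ln_in_parallel_attn=None).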
@@ -128,9 +131,10 @@ class FlashRWAttention(torch.nn.Module):
         self.num_heads_kv = config.n_head_kv
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.rope_theta = config.rope_theta
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
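
The effect of threading config.rope_theta through here, instead of the hard-coded base=10000.0, is a different set of rotary inverse frequencies. A standalone sketch of the standard RoPE frequency formula, not the PositionRotaryEmbedding implementation itself:

import torch

def rope_inv_freq(head_size: int, theta: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: 1 / theta**(2i / d) for i = 0 .. d/2 - 1.
    exponents = torch.arange(0, head_size, 2, dtype=torch.float32) / head_size
    return 1.0 / (theta ** exponents)

# A larger base stretches the rotary wavelengths, which is what long-context
# Falcon variants rely on; 10_000.0 reproduces the old hard-coded behaviour.
print(rope_inv_freq(64, theta=10_000.0)[:4])
print(rope_inv_freq(64, theta=500_000.0)[:4])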
@@ -240,6 +244,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
         self.num_groups = num_groups
+        self.rope_theta = config.rope_theta
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config, dim=self.head_size, base=10000.0, device=weights.device
@@ -253,7 +258,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group.size() > self.num_groups:
             raise NotImplementedError(
-                f"Tensor Parallelism is not implemented for world_size > n groups"
+                "Tensor Parallelism is not implemented for world_size > n groups"
             )
         if self.num_groups % process_group.size() != 0:
             raise NotImplementedError(
@@ -455,6 +460,7 @@ class FlashRWLayer(nn.Module):
                 max_s,
             )
 
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+            if self.post_attention_layernorm is not None:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
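
The new guard implies that FlashRWLayer only sometimes has a post-attention layer norm loaded. A minimal, hypothetical sketch of that pattern with plain torch modules (TinyBlock and its flag are illustrative names, not TGI code):

import torch
from torch import nn

class TinyBlock(nn.Module):
    # Illustrative stand-in for the guarded path, not the real FlashRWLayer.
    def __init__(self, hidden_size: int, use_post_attention_ln: bool):
        super().__init__()
        self.post_attention_layernorm = (
            nn.LayerNorm(hidden_size) if use_post_attention_ln else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors the new check: only normalize when the layer norm was loaded.
        if self.post_attention_layernorm is not None:
            hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states

x = torch.randn(2, 8)
print(TinyBlock(8, use_post_attention_ln=False)(x).shape)  # torch.Size([2, 8])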
@@ -463,11 +469,18 @@ class FlashRWLayer(nn.Module):
             return mlp_output, residual
 
 
-class FlashRWLargeLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+class FlashRWLayerNorm(nn.Module):
+    def __init__(self, config, prefix, weights):
         super().__init__()
-        prefix = f"transformer.h.{layer_id}"
+        self.num_ln = config.num_ln_in_parallel_attn
 
+        if self.num_ln == 1:
+            self.input_ln = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        elif self.num_ln == 2:
             self.ln_attn = FastLayerNorm.load(
                 prefix=f"{prefix}.ln_attn",
                 weights=weights,
@@ -478,6 +491,29 @@ class FlashRWLargeLayer(nn.Module):
                 weights=weights,
                 eps=config.layer_norm_epsilon,
             )
+        else:
+            raise ValueError("Number of layer norms can either be 1 or 2.")
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+    ):
+        if self.num_ln == 1:
+            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+            return ln_hidden_states, ln_hidden_states, residual
+        elif self.num_ln == 2:
+            ln_attn, residual = self.ln_attn(hidden_states, residual)
+            ln_mlp, _ = self.ln_mlp(residual)
+            return ln_attn, ln_mlp, residual
+
+
+class FlashRWLargeLayer(nn.Module):
+    def __init__(self, layer_id, config, weights):
+        super().__init__()
+        prefix = f"transformer.h.{layer_id}"
+
+        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
 
         self.self_attention = FlashRWLargeAttention(
             config,
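
For readers skimming the new class above: FlashRWLayerNorm hides whether a checkpoint ships one shared layer norm or separate attention/MLP layer norms, and always hands back a pair of normalized activations plus the residual. A simplified, self-contained sketch of that dispatch, with plain nn.LayerNorm standing in for FastLayerNorm and the residual bookkeeping omitted:

import torch
from torch import nn

class LayerNormDispatch(nn.Module):
    # Simplified stand-in for FlashRWLayerNorm; names and residual handling
    # are illustrative, not the TGI implementation.
    def __init__(self, hidden_size: int, num_ln: int):
        super().__init__()
        self.num_ln = num_ln
        if num_ln == 1:
            self.input_ln = nn.LayerNorm(hidden_size)
        elif num_ln == 2:
            self.ln_attn = nn.LayerNorm(hidden_size)
            self.ln_mlp = nn.LayerNorm(hidden_size)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states: torch.Tensor):
        if self.num_ln == 1:
            # One norm feeds both the attention branch and the MLP branch.
            ln = self.input_ln(hidden_states)
            return ln, ln
        # Two norms: independent inputs for the attention and MLP branches.
        return self.ln_attn(hidden_states), self.ln_mlp(hidden_states)

x = torch.randn(2, 8)
ln_attn, ln_mlp = LayerNormDispatch(8, num_ln=1)(x)
print(torch.equal(ln_attn, ln_mlp))  # True: the single-norm case shares one output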
@@ -503,8 +539,8 @@ class FlashRWLargeLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        ln_attn, residual = self.ln_attn(hidden_states, residual)
-        ln_mlp, _ = self.ln_mlp(residual)
+        # Layer norm.
+        ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
 
         # Self attention.
         attn_output = self.self_attention(