mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
add: support for falcon-10B architecture.
This commit is contained in:
parent
762dbf3f19
commit
f2b3d8d7ed
@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.flash_attn import attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -48,6 +47,7 @@ class RWConfig(PretrainedConfig):
|
||||
hidden_size=64,
|
||||
num_hidden_layers=None,
|
||||
num_attention_heads=None,
|
||||
num_ln_in_prallel_attention=None,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
@ -61,6 +61,7 @@ class RWConfig(PretrainedConfig):
|
||||
new_decoder_architecture=None,
|
||||
bias=False,
|
||||
parallel_attn=False,
|
||||
rope_theta=10_000.0,
|
||||
**kwargs,
|
||||
):
|
||||
if alibi:
|
||||
@ -71,6 +72,7 @@ class RWConfig(PretrainedConfig):
|
||||
self.model_type = model_type
|
||||
self.alibi = False
|
||||
self.rotary = True
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
# Backward compatibility with n_embed kwarg
|
||||
@ -87,6 +89,7 @@ class RWConfig(PretrainedConfig):
|
||||
else kwargs.pop("n_head", 8)
|
||||
)
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.num_ln_in_parallel_attention = num_ln_in_prallel_attention
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.hidden_dropout = hidden_dropout
|
||||
@ -128,9 +131,10 @@ class FlashRWAttention(torch.nn.Module):
|
||||
self.num_heads_kv = config.n_head_kv
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
@ -240,6 +244,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
self.num_groups = num_groups
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
@ -253,7 +258,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
|
||||
if process_group.size() > self.num_groups:
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for world_size > n groups"
|
||||
"Tensor Parallelism is not implemented for world_size > n groups"
|
||||
)
|
||||
if self.num_groups % process_group.size() != 0:
|
||||
raise NotImplementedError(
|
||||
@ -455,29 +460,60 @@ class FlashRWLayer(nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
if self.post_attention_layernorm is not None:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
|
||||
return mlp_output, residual
|
||||
|
||||
class FlashRWLayerNorm(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.num_ln = config.num_ln_in_parallel_attn
|
||||
|
||||
if self.num_ln == 1:
|
||||
self.input_ln = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
elif self.num_ln == 2:
|
||||
self.ln_attn = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_attn",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_mlp = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_mlp",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Number of layer norms can either be 1 or 2.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
):
|
||||
if self.num_ln == 1:
|
||||
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
|
||||
return ln_hidden_states, ln_hidden_states, residual
|
||||
elif self.num_ln == 2:
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
ln_mlp, _ = self.ln_mlp(residual)
|
||||
return ln_attn, ln_mlp, residual
|
||||
|
||||
|
||||
class FlashRWLargeLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"transformer.h.{layer_id}"
|
||||
self.ln_attn = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_attn",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_mlp = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_mlp",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
|
||||
|
||||
self.self_attention = FlashRWLargeAttention(
|
||||
config,
|
||||
@ -503,8 +539,8 @@ class FlashRWLargeLayer(nn.Module):
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
ln_mlp, _ = self.ln_mlp(residual)
|
||||
# Layer norm.
|
||||
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
|
||||
|
||||
# Self attention.
|
||||
attn_output = self.self_attention(
|
||||
|
Loading…
Reference in New Issue
Block a user