mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
add: support for falcon-10B architecture.
This commit is contained in:
parent
762dbf3f19
commit
f2b3d8d7ed
@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.flash_attn import attention
|
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -48,6 +47,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
hidden_size=64,
|
hidden_size=64,
|
||||||
num_hidden_layers=None,
|
num_hidden_layers=None,
|
||||||
num_attention_heads=None,
|
num_attention_heads=None,
|
||||||
|
num_ln_in_prallel_attention=None,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
@ -61,6 +61,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
new_decoder_architecture=None,
|
new_decoder_architecture=None,
|
||||||
bias=False,
|
bias=False,
|
||||||
parallel_attn=False,
|
parallel_attn=False,
|
||||||
|
rope_theta=10_000.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if alibi:
|
if alibi:
|
||||||
@ -71,6 +72,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.alibi = False
|
self.alibi = False
|
||||||
self.rotary = True
|
self.rotary = True
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
# Backward compatibility with n_embed kwarg
|
# Backward compatibility with n_embed kwarg
|
||||||
@ -87,6 +89,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
else kwargs.pop("n_head", 8)
|
else kwargs.pop("n_head", 8)
|
||||||
)
|
)
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.num_ln_in_parallel_attention = num_ln_in_prallel_attention
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.hidden_dropout = hidden_dropout
|
self.hidden_dropout = hidden_dropout
|
||||||
@ -128,9 +131,10 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
self.num_heads_kv = config.n_head_kv
|
self.num_heads_kv = config.n_head_kv
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.head_size = self.hidden_size // self.num_heads
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
|
||||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
|
||||||
)
|
)
|
||||||
self.softmax_scale = self.head_size ** (-0.5)
|
self.softmax_scale = self.head_size ** (-0.5)
|
||||||
|
|
||||||
@ -240,6 +244,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.head_size = hidden_size // num_heads
|
self.head_size = hidden_size // num_heads
|
||||||
self.num_groups = num_groups
|
self.num_groups = num_groups
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
|
||||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||||
@ -253,7 +258,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
|
|
||||||
if process_group.size() > self.num_groups:
|
if process_group.size() > self.num_groups:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Tensor Parallelism is not implemented for world_size > n groups"
|
"Tensor Parallelism is not implemented for world_size > n groups"
|
||||||
)
|
)
|
||||||
if self.num_groups % process_group.size() != 0:
|
if self.num_groups % process_group.size() != 0:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -455,6 +460,7 @@ class FlashRWLayer(nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.post_attention_layernorm is not None:
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual
|
hidden_states, residual
|
||||||
)
|
)
|
||||||
@ -463,11 +469,18 @@ class FlashRWLayer(nn.Module):
|
|||||||
|
|
||||||
return mlp_output, residual
|
return mlp_output, residual
|
||||||
|
|
||||||
|
class FlashRWLayerNorm(nn.Module):
|
||||||
class FlashRWLargeLayer(nn.Module):
|
def __init__(self, config, prefix, weights):
|
||||||
def __init__(self, layer_id, config, weights):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"transformer.h.{layer_id}"
|
self.num_ln = config.num_ln_in_parallel_attn
|
||||||
|
|
||||||
|
if self.num_ln == 1:
|
||||||
|
self.input_ln = FastLayerNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_epsilon,
|
||||||
|
)
|
||||||
|
elif self.num_ln == 2:
|
||||||
self.ln_attn = FastLayerNorm.load(
|
self.ln_attn = FastLayerNorm.load(
|
||||||
prefix=f"{prefix}.ln_attn",
|
prefix=f"{prefix}.ln_attn",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@ -478,6 +491,29 @@ class FlashRWLargeLayer(nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_epsilon,
|
eps=config.layer_norm_epsilon,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Number of layer norms can either be 1 or 2.")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
):
|
||||||
|
if self.num_ln == 1:
|
||||||
|
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
|
||||||
|
return ln_hidden_states, ln_hidden_states, residual
|
||||||
|
elif self.num_ln == 2:
|
||||||
|
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||||
|
ln_mlp, _ = self.ln_mlp(residual)
|
||||||
|
return ln_attn, ln_mlp, residual
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRWLargeLayer(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"transformer.h.{layer_id}"
|
||||||
|
|
||||||
|
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
|
||||||
|
|
||||||
self.self_attention = FlashRWLargeAttention(
|
self.self_attention = FlashRWLargeAttention(
|
||||||
config,
|
config,
|
||||||
@ -503,8 +539,8 @@ class FlashRWLargeLayer(nn.Module):
|
|||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
):
|
):
|
||||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
# Layer norm.
|
||||||
ln_mlp, _ = self.ln_mlp(residual)
|
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
|
||||||
|
|
||||||
# Self attention.
|
# Self attention.
|
||||||
attn_output = self.self_attention(
|
attn_output = self.self_attention(
|
||||||
|
Loading…
Reference in New Issue
Block a user