Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
chore: removed repeating code.

commit 8ce8265966 (parent e45ede2cde)
@@ -6,21 +6,16 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
 from text_generation_server.layers import (
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.layernorm import (
-    FastLayerNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
-)
+from text_generation_server.layers.layernorm import FastLayerNorm
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils import flash_attn, paged_attention
 
 
 def load_row(config, prefix: str, weights, bias: bool):
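The hunk above only reorganizes the import section: the names imported from `text_generation_server.layers` are reordered, the single-name parenthesized imports are collapsed onto one line each, the `utils` import moves below the `layers` imports, and the `attention` import from `text_generation_server.utils.flash_attn` is dropped entirely (presumably because it is unused in this module). As a reminder that this kind of layout change does not affect which objects get bound, here is a minimal, self-contained sketch using standard-library modules as stand-ins for the `text_generation_server` packages:

# Stand-in demonstration (os.path instead of the text_generation_server modules):
# reordering the names inside an import and collapsing the parentheses only
# changes formatting; the objects bound in the module namespace are identical.
from os.path import (
    join,
    split,
)

import os.path

assert join is os.path.join and split is os.path.split

from os.path import split, join  # same bindings, different order and layout

assert join is os.path.join and split is os.path.split
print("import layout is purely cosmetic")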
@@ -520,162 +515,6 @@ class FlashRWLayerNorm(nn.Module):
             return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
 
 
 class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
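The 156 lines removed by this hunk are four verbatim, back-to-back copies of the `FlashRWLayerNorm` class, each preceded by two blank lines; the definition that ends at the `return ln_attn, ln_mlp, residual` context line above is the one that stays. Dropping the copies cannot change behavior: a repeated `class` statement at module level just rebinds the same name, so later code only ever sees the last (here, identical) definition. A minimal, self-contained sketch of that Python behavior, using a generic class name rather than the actual model code:

# Minimal sketch: a repeated module-level class definition rebinds the same name.
# Every copy's body still executes, but only the last binding is visible to the
# rest of the module, so deleting identical earlier copies changes nothing.
class Dup:
    VERSION = 1


class Dup:  # redefinition: rebinds the name Dup to this class
    VERSION = 2


assert Dup.VERSION == 2
print("visible definition:", Dup.VERSION)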