chore: removed repeating code.

This commit is contained in:
Nilabhra 2024-05-14 11:58:12 +04:00
parent e45ede2cde
commit 8ce8265966

View File

@ -6,21 +6,16 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, SpeculativeHead,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
SpeculativeHead, TensorParallelRowLinear,
get_linear, get_linear,
) )
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import FastLayerNorm
FastLayerNorm, from text_generation_server.layers.rotary import PositionRotaryEmbedding
) from text_generation_server.utils import flash_attn, paged_attention
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
@ -520,162 +515,6 @@ class FlashRWLayerNorm(nn.Module):
return ln_attn, ln_mlp, residual return ln_attn, ln_mlp, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
elif self.num_ln == 2:
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
else:
raise ValueError("Number of layer norms can either be 1 or 2.")
def forward(
self,
hidden_states,
residual,
):
if self.num_ln == 1:
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
return ln_hidden_states, ln_hidden_states, residual
elif self.num_ln == 2:
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
return ln_attn, ln_mlp, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
elif self.num_ln == 2:
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
else:
raise ValueError("Number of layer norms can either be 1 or 2.")
def forward(
self,
hidden_states,
residual,
):
if self.num_ln == 1:
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
return ln_hidden_states, ln_hidden_states, residual
elif self.num_ln == 2:
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
return ln_attn, ln_mlp, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
elif self.num_ln == 2:
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
else:
raise ValueError("Number of layer norms can either be 1 or 2.")
def forward(
self,
hidden_states,
residual,
):
if self.num_ln == 1:
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
return ln_hidden_states, ln_hidden_states, residual
elif self.num_ln == 2:
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
return ln_attn, ln_mlp, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
elif self.num_ln == 2:
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
else:
raise ValueError("Number of layer norms can either be 1 or 2.")
def forward(
self,
hidden_states,
residual,
):
if self.num_ln == 1:
ln_hidden_states, residual = self.input_ln(hidden_states, residual)
return ln_hidden_states, ln_hidden_states, residual
elif self.num_ln == 2:
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
return ln_attn, ln_mlp, residual
class FlashRWLargeLayer(nn.Module): class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, config, weights):
super().__init__() super().__init__()