diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index dd7b043a..fa463a19 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,21 +6,16 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
 from text_generation_server.layers import (
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.layernorm import (
-    FastLayerNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
-)
+from text_generation_server.layers.layernorm import FastLayerNorm
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils import flash_attn, paged_attention
 
 
 def load_row(config, prefix: str, weights, bias: bool):
@@ -520,162 +515,6 @@ class FlashRWLayerNorm(nn.Module):
         return ln_attn, ln_mlp, residual
 
 
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
-class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
-
-        if self.num_ln == 1:
-            self.input_ln = FastLayerNorm.load(
-                prefix=f"{prefix}.input_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        elif self.num_ln == 2:
-            self.ln_attn = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_attn",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-            self.ln_mlp = FastLayerNorm.load(
-                prefix=f"{prefix}.ln_mlp",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            )
-        else:
-            raise ValueError("Number of layer norms can either be 1 or 2.")
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        if self.num_ln == 1:
-            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
-            return ln_hidden_states, ln_hidden_states, residual
-        elif self.num_ln == 2:
-            ln_attn, residual = self.ln_attn(hidden_states, residual)
-            ln_mlp, _ = self.ln_mlp(residual)
-            return ln_attn, ln_mlp, residual
-
-
 class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
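
For reference, here is a minimal standalone sketch of the dispatch pattern the single surviving FlashRWLayerNorm implements: Falcon-style checkpoints with parallel attention carry either one shared pre-norm or separate ln_attn / ln_mlp norms. The sketch substitutes plain torch.nn.LayerNorm for TGI's FastLayerNorm (which fuses the residual add into the norm kernel), so the residual add is emulated explicitly; the class and variable names here are illustrative only, not part of the diff.

```python
import torch
from torch import nn


class ParallelAttnLayerNorm(nn.Module):
    """Sketch of the 1-vs-2 layer-norm dispatch used by parallel-attention blocks."""

    def __init__(self, hidden_size: int, num_ln: int, eps: float = 1e-5):
        super().__init__()
        self.num_ln = num_ln
        if num_ln == 1:
            # One norm feeds both the attention and MLP branches.
            self.input_ln = nn.LayerNorm(hidden_size, eps=eps)
        elif num_ln == 2:
            # Separate norms for the attention and MLP branches.
            self.ln_attn = nn.LayerNorm(hidden_size, eps=eps)
            self.ln_mlp = nn.LayerNorm(hidden_size, eps=eps)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states, residual=None):
        # FastLayerNorm performs this residual add inside the kernel and
        # returns the summed activations; emulate that behavior here.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        if self.num_ln == 1:
            ln = self.input_ln(hidden_states)
            # Both branches receive the same normed activations.
            return ln, ln, residual
        ln_attn = self.ln_attn(hidden_states)
        ln_mlp = self.ln_mlp(hidden_states)
        return ln_attn, ln_mlp, residual


# Both branches of a parallel-attention layer consume the outputs, and the
# shared residual is threaded through to the next layer:
norm = ParallelAttnLayerNorm(hidden_size=64, num_ln=2)
x = torch.randn(4, 64)
ln_attn, ln_mlp, residual = norm(x)
assert ln_attn.shape == ln_mlp.shape == residual.shape == x.shape
```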