diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 7d252a0e..01b4ef10 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -7,8 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
-from text_generation_server.layers import (
+from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -139,10 +138,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=self.rope_theta,
-            device=weights.device,
+            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -480,6 +476,44 @@ class FlashRWLayer(nn.Module):
 
         return mlp_output, residual
 
+
+class FlashRWLayerNorm(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.num_ln = config.num_ln_in_parallel_attn
+
+        if self.num_ln == 1:
+            self.input_ln = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        elif self.num_ln == 2:
+            self.ln_attn = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_attn",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+            self.ln_mlp = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_mlp",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            raise ValueError("Number of layer norms can either be 1 or 2.")
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+    ):
+        if self.num_ln == 1:
+            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+            return ln_hidden_states, ln_hidden_states, residual
+        elif self.num_ln == 2:
+            ln_attn, residual = self.ln_attn(hidden_states, residual)
+            ln_mlp, _ = self.ln_mlp(residual)
+            return ln_attn, ln_mlp, residual
 
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
@@ -524,7 +558,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
-        
+
         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
 
         self.self_attention = FlashRWLargeAttention(
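
For context on the class introduced in this patch: `FlashRWLayerNorm` selects between a single shared pre-attention layer norm (`num_ln_in_parallel_attn == 1`) and separate attention/MLP norms (`num_ln_in_parallel_attn == 2`), and its `forward` always returns a `(ln_attn, ln_mlp, residual)` triple so `FlashRWLargeLayer` can consume it uniformly. Below is a minimal, self-contained sketch of that dispatch; it is not part of the patch. `DemoConfig` and `DemoRWLayerNorm` are illustrative names, plain `torch.nn.LayerNorm` stands in for `FastLayerNorm.load`, and the residual handling is only assumed to be analogous to what `FastLayerNorm` does in-kernel.

```python
import torch
from torch import nn


class DemoConfig:
    """Stub standing in for the Falcon/RW PretrainedConfig (illustrative only)."""

    def __init__(self, hidden_size=64, num_ln_in_parallel_attn=2, layer_norm_epsilon=1e-5):
        self.hidden_size = hidden_size
        self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
        self.layer_norm_epsilon = layer_norm_epsilon


class DemoRWLayerNorm(nn.Module):
    """Mirrors the dispatch of the patch's FlashRWLayerNorm, without weight loading."""

    def __init__(self, config: DemoConfig):
        super().__init__()
        self.num_ln = config.num_ln_in_parallel_attn

        if self.num_ln == 1:
            # One shared norm feeds both the attention and the MLP branch.
            self.input_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        elif self.num_ln == 2:
            # Separate norms for the attention branch and the MLP branch.
            self.ln_attn = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
            self.ln_mlp = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states, residual=None):
        # Fold the residual add in here and hand the updated residual back to the
        # caller (assumed to be analogous to FastLayerNorm's fused behavior).
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states

        if self.num_ln == 1:
            ln_hidden_states = self.input_ln(hidden_states)
            # The same tensor is returned twice: attention and MLP share one norm.
            return ln_hidden_states, ln_hidden_states, residual
        # num_ln == 2: attention and MLP each get their own normalized input.
        return self.ln_attn(hidden_states), self.ln_mlp(hidden_states), residual


if __name__ == "__main__":
    config = DemoConfig()
    layer_norm = DemoRWLayerNorm(config)
    x = torch.randn(4, config.hidden_size)
    ln_attn, ln_mlp, residual = layer_norm(x)
    print(ln_attn.shape, ln_mlp.shape, residual.shape)
```

Either way, the caller sees the same three-value contract, which is why `FlashRWLargeLayer` in the last hunk can use a single `self.ln_layer(...)` call regardless of how many layer norms the checkpoint provides.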