diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 4db4a9a0..205fa393 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,19 +1,19 @@
+from typing import List, Optional, Tuple
+
 import torch
 import torch.distributed
-
 from torch import nn
-from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
+from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils import flash_attn, paged_attention
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
-    TensorParallelEmbedding,
-    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    SpeculativeHead,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
     get_linear,
 )
 
@@ -134,7 +134,10 @@ class FlashRWAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -247,7 +250,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -469,6 +475,7 @@ class FlashRWLayer(nn.Module):
 
             return mlp_output, residual
 
+
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -512,7 +519,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
-        
+
         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
         self.self_attention = FlashRWLargeAttention(
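
Within these hunks, the one behavioral change is `FlashRWLargeAttention` now building its rotary embedding with `base=self.rope_theta` (read from `config.rope_theta`) instead of the hard-coded `10000.0`, matching what `FlashRWAttention` already did; the rest is import sorting and line wrapping. The snippet below is a minimal, standalone sketch of why the base matters, using plain PyTorch rather than TGI's `PositionRotaryEmbedding` API; the `rope_inv_freq` helper and the example base values are illustrative, not taken from the diff.

```python
import torch


def rope_inv_freq(dim: int, base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: one frequency per pair of head dimensions.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


# A larger base stretches the rotation wavelengths, so a hard-coded 10000.0
# misplaces positions for checkpoints trained with a non-default rope_theta.
print(rope_inv_freq(64, 10000.0)[:4])
print(rope_inv_freq(64, 500000.0)[:4])
```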