From d619666d236a3e1edef6424d7a7d55deae01bc67 Mon Sep 17 00:00:00 2001
From: Nilabhra
Date: Tue, 14 May 2024 10:14:18 +0400
Subject: [PATCH] fix: setting the rotary base from the config for the grouped
 query models.

---
 .../models/custom_modeling/flash_rw_modeling.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 921540b8..b0b9d500 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,12 +6,14 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils import flash_attn, paged_attention
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
+    FastLayerNorm,
+    PositionRotaryEmbedding,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
 from text_generation_server.layers.layernorm import (
@@ -479,6 +481,7 @@ class FlashRWLayer(nn.Module):
 
             return mlp_output, residual
 
+
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -678,7 +681,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
-
+
         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
         self.self_attention = FlashRWLargeAttention(