fix: setting the rotary base from the config for the grouped query models.

Nilabhra 2024-05-14 10:14:18 +04:00
parent 2fcbf5f3b9
commit d619666d23


@@ -6,12 +6,14 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils import flash_attn, paged_attention
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
+    FastLayerNorm,
+    PositionRotaryEmbedding,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
 from text_generation_server.layers.layernorm import (
@@ -479,6 +481,7 @@ class FlashRWLayer(nn.Module):
             return mlp_output, residual
 
 
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -678,7 +681,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
         self.self_attention = FlashRWLargeAttention(
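
The last hunk is cut off before the line that actually changes, so here is a minimal sketch of what the commit title describes: reading the rotary base from the model config instead of hardcoding it in the grouped-query ("large") attention path. This is an illustration, not the diff body; the class name below is a hypothetical stand-in, and the `rope_theta` attribute name, the 10000.0 fallback, and the exact `PositionRotaryEmbedding.static(...)` call shape are assumptions inferred from the new imports rather than taken from this commit.

import torch

from text_generation_server.utils.layers import PositionRotaryEmbedding


class GroupedQueryAttentionSketch(torch.nn.Module):
    # Hypothetical stand-in for the grouped-query attention module;
    # only the rotary-embedding setup is shown.
    def __init__(self, config, prefix, weights):
        super().__init__()
        # config.hidden_size / config.n_head are the usual RW config fields (assumed here).
        self.head_size = config.hidden_size // config.n_head

        # Take the rotary base from the config (falling back to the common 10000.0 default)
        # instead of a hardcoded constant, so grouped-query checkpoints trained with a
        # different base build the correct rotary embedding.
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=getattr(config, "rope_theta", 10000.0),
            device=weights.device,
        )

The getattr fallback in the sketch simply keeps configs that carry no explicit rotary-base field on the conventional default.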