fix: set the rotary base from the config for grouped-query models.

Nilabhra 2024-05-14 10:14:18 +04:00
parent f2b3d8d7ed
commit a24bf62368

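The substantive change in the diff below is in FlashRWLargeAttention (the grouped-query attention path): its rotary embedding was built with a hardcoded base of 10000.0, while FlashRWAttention already passed config.rope_theta; the FlashRWAttention hunk is only a cosmetic reflow of the same call. As a minimal sketch of why the base matters, assuming the standard RoPE inverse-frequency formula and a hypothetical DummyConfig rather than the repository's PositionRotaryEmbedding internals:

import torch

def rotary_inv_freq(dim: int, base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: 1 / base ** (2i / dim) for i = 0 .. dim/2 - 1.
    # `base` is the value the diff now takes from config.rope_theta instead of
    # hardcoding 10000.0.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Hypothetical config stand-in (not the repository's PretrainedConfig):
class DummyConfig:
    rope_theta = 5000000.0  # illustrative value; real checkpoints set this in their config

head_size = 64
old_freqs = rotary_inv_freq(head_size, 10000.0)                 # previous hardcoded base
new_freqs = rotary_inv_freq(head_size, DummyConfig.rope_theta)  # base read from the config
print(torch.allclose(old_freqs, new_freqs))  # False: the rotation angles differ

A checkpoint trained with a non-default rope_theta expects those per-position rotation angles; keeping the hardcoded 10000.0 at inference silently mismatches them and degrades positional behaviour for such models.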

@@ -1,19 +1,19 @@
+from typing import List, Optional, Tuple
 import torch
 import torch.distributed
 from torch import nn
-from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
+from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils import flash_attn, paged_attention
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
-    TensorParallelEmbedding,
-    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    SpeculativeHead,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
     get_linear,
 )
@@ -134,7 +134,10 @@ class FlashRWAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)
@@ -247,7 +250,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)
@@ -469,6 +475,7 @@ class FlashRWLayer(nn.Module):
             return mlp_output, residual


 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()