mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: setting the rotary base from the config for the grouped query models.
This commit is contained in:
parent
22c005fac3
commit
dcd2b4425c
@ -6,12 +6,14 @@ from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils import flash_attn, paged_attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
SpeculativeHead,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.layernorm import (
|
||||
@ -138,7 +140,10 @@ class FlashRWAttention(torch.nn.Module):
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=self.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
@ -476,6 +481,7 @@ class FlashRWLayer(nn.Module):
|
||||
|
||||
return mlp_output, residual
|
||||
|
||||
|
||||
class FlashRWLayerNorm(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
|
Loading…
Reference in New Issue
Block a user