mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: setting the rotary base from the config for the grouped query models.
This commit is contained in:
parent f2b3d8d7ed
commit a24bf62368
@@ -1,19 +1,19 @@
+from typing import List, Optional, Tuple
+
 import torch
 import torch.distributed
-
 from torch import nn
-from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
+from transformers.modeling_utils import PreTrainedModel

-from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils import flash_attn, paged_attention
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
-    TensorParallelEmbedding,
-    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    SpeculativeHead,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
     get_linear,
 )

@@ -134,7 +134,10 @@ class FlashRWAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta

         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=self.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)

@@ -247,7 +250,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rope_theta = config.rope_theta

         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=self.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)

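The hunk above is the behavioural part of the fix: positions multiplied by those inverse frequencies give the angles by which each query/key channel pair is rotated, so switching the base from the hardcoded 10000.0 to config.rope_theta changes every attention score for models whose config sets a different value. A reference-style rotation is sketched below; apply_rope is a hypothetical stand-in for illustration, not the flash/paged attention kernels TGI actually calls.

import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float) -> torch.Tensor:
    # x: [seq_len, head_size]; rotate each (even, odd) channel pair by a
    # position-dependent angle derived from `base`.
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions[:, None].float() * inv_freq[None, :]  # [seq_len, dim // 2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Same query, two bases: the rotated vectors (and hence attention scores) differ.
q = torch.randn(8, 64)  # 8 positions, head_size 64
print(torch.allclose(apply_rope(q, torch.arange(8), 10000.0),
                     apply_rope(q, torch.arange(8), 1_000_000.0)))  # False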
@@ -469,6 +475,7 @@ class FlashRWLayer(nn.Module):

            return mlp_output, residual


+
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()