mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
added a bunch of cleanup based on comments in PR; removed conditionals from LayerNormParameterized and renamed to MLPSpeculatorLayerNorm; now using modules for tensor-parallel (this is work in progress - looking into if this is right approach); fixed issue with getting medusa model; fixed for more efficient loading
This commit is contained in:
parent
6e5c19ec44
commit
453e91f755
@ -442,69 +442,46 @@ class ResBlock(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.act(self.linear(x))
|
return x + self.act(self.linear(x))
|
||||||
|
|
||||||
class LayerNormParameterized(nn.Module):
|
class MLPSpeculatorLayerNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
A generalized LayerNorm implementation. With all optional arguments set to True, equivalent to nn.LayerNorm up to epsilon stabilization term
|
A L2 normalization implementation
|
||||||
(this class divides inputs by min(norm, eps), while nn.LayerNorm divides by norm + eps).
|
|
||||||
...
|
...
|
||||||
Args
|
Args
|
||||||
----
|
----
|
||||||
normalized_shape : int
|
normalized_shape : int
|
||||||
Dimensionality of input data (size of final tensor axis)
|
Dimensionality of input data (size of final tensor axis)
|
||||||
|
elementwise_scale_weight : torch.Tensor
|
||||||
|
learned scaling term after normalization?
|
||||||
|
elementwise_shift_bias : torch.Tensor
|
||||||
|
learned bias term after normalization?
|
||||||
eps : float
|
eps : float
|
||||||
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
|
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
|
||||||
elementwise_scale : bool
|
|
||||||
Include a learned scaling term after normalization?
|
|
||||||
elementwise_shift : bool
|
|
||||||
Include a learned bias term after normalization?
|
|
||||||
use_mean : bool
|
|
||||||
Recenter inputs around zero before normalizing, or just rescale?
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
normalized_shape,
|
normalized_shape,
|
||||||
|
elementwise_scale_weight: torch.Tensor,
|
||||||
|
elementwise_shift_bias: torch.Tensor,
|
||||||
eps=1e-06,
|
eps=1e-06,
|
||||||
elementwise_scale=True,
|
|
||||||
elementwise_shift=False,
|
|
||||||
use_mean=False,
|
|
||||||
use_high_precision_pow=False,
|
|
||||||
):
|
):
|
||||||
super(LayerNormParameterized, self).__init__()
|
super(MLPSpeculatorLayerNorm, self).__init__()
|
||||||
self.normalized_shape = normalized_shape
|
self.normalized_shape = normalized_shape
|
||||||
|
self.weight = nn.Parameter(elementwise_scale_weight)
|
||||||
|
self.bias = nn.Parameter(elementwise_shift_bias)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.elementwise_scale = elementwise_scale
|
|
||||||
self.elementwise_shift = elementwise_shift
|
|
||||||
self.use_mean = use_mean
|
|
||||||
self.use_high_precision_pow = use_high_precision_pow
|
|
||||||
|
|
||||||
if self.elementwise_scale:
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.normalized_shape))
|
|
||||||
if self.elementwise_shift:
|
|
||||||
self.bias = nn.Parameter(torch.empty(self.normalized_shape))
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
if self.elementwise_scale:
|
|
||||||
self.weight.data.fill_(1)
|
|
||||||
if self.elementwise_shift:
|
|
||||||
self.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.use_mean:
|
|
||||||
x = x - x.mean(-1, keepdim=True)
|
|
||||||
xf = x
|
xf = x
|
||||||
if self.use_high_precision_pow:
|
|
||||||
xf = x.float()
|
|
||||||
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
x = xf.type_as(x)
|
x = xf.type_as(x)
|
||||||
if self.elementwise_scale:
|
|
||||||
x = self.weight * x
|
x = self.weight * x
|
||||||
if self.elementwise_shift:
|
|
||||||
x = x + self.bias
|
x = x + self.bias
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class MLPSpeculatorModel(torch.nn.Module):
|
class MLPSpeculatorModel(torch.nn.Module):
|
||||||
def __init__(self, config, emb, proj, head, ln):
|
def __init__(self, config, prefix, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.n_predict = config.n_predict
|
self.n_predict = config.n_predict
|
||||||
@ -512,10 +489,30 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
|
inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
|
||||||
self.inner_dim = inner_dim
|
self.inner_dim = inner_dim
|
||||||
self.config = config.vocab_size
|
self.config = config.vocab_size
|
||||||
self.emb = emb
|
self.emb = nn.ModuleList(
|
||||||
self.proj = proj
|
[TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
|
||||||
self.head = head
|
)
|
||||||
self.ln = ln
|
self.proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
dim=0
|
||||||
|
)
|
||||||
|
self.head = nn.ModuleList(
|
||||||
|
[TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
|
||||||
|
)
|
||||||
|
self.ln = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MLPSpeculatorLayerNorm(
|
||||||
|
config.inner_dim,
|
||||||
|
elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
|
||||||
|
elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
|
||||||
|
)
|
||||||
|
for i in range(config.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||||
self.state_weight = 0.5 ** (0.5 / self.n_predict)
|
self.state_weight = 0.5 ** (0.5 / self.n_predict)
|
||||||
self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
|
self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
|
||||||
@ -565,36 +562,8 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
|
1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
|
||||||
) # b n h v
|
) # b n h v
|
||||||
|
|
||||||
def load(self, config, prefix, weights):
|
|
||||||
self.emb = nn.ModuleList(
|
|
||||||
[nn.Embedding(config.vocab_size, config.inner_dim) for _ in range(config.n_predict)]
|
|
||||||
)
|
|
||||||
self.proj = nn.ModuleList(
|
|
||||||
[
|
|
||||||
nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False)
|
|
||||||
for i in range(config.n_predict)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.head = nn.ModuleList(
|
|
||||||
[nn.Linear(config.inner_dim, config.vocab_size, bias=False) for _ in range(config.n_predict)]
|
|
||||||
)
|
|
||||||
self.ln = nn.ModuleList(
|
|
||||||
[
|
|
||||||
LayerNormParameterized(
|
|
||||||
config.inner_dim, elementwise_shift=True, elementwise_scale=True
|
|
||||||
)
|
|
||||||
for _ in range(config.n_predict)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for i in range(config.n_predict):
|
|
||||||
self.emb[i].weight.data.copy_(weights.get_tensor(f"{prefix}.emb.{i}.weight"))
|
|
||||||
self.proj[i].weight.data.copy_(weights.get_tensor(f"{prefix}.proj.{i}.weight"))
|
|
||||||
self.ln[i].weight.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.weight"))
|
|
||||||
self.ln[i].bias.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.bias"))
|
|
||||||
self.head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight"))
|
|
||||||
|
|
||||||
|
class MLPSpeculatorHead(nn.Module):
|
||||||
class MLPSpeculatorHeadV1(nn.Module):
|
|
||||||
def __init__(self, lm_head, mlp_speculator):
|
def __init__(self, lm_head, mlp_speculator):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lm_head = lm_head
|
self.lm_head = lm_head
|
||||||
@ -629,9 +598,9 @@ class MLPSpeculatorHeadV1(nn.Module):
|
|||||||
)
|
)
|
||||||
routing[k] = filename
|
routing[k] = filename
|
||||||
|
|
||||||
mlp_speculator = MLPSpeculatorModel.load(speculator_config, prefix, weights)
|
mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights)
|
||||||
lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
|
lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
|
||||||
return MLPSpeculatorHeadV1(lm_head, mlp_speculator)
|
return MLPSpeculatorHead(lm_head, mlp_speculator)
|
||||||
|
|
||||||
|
|
||||||
class MedusaModel(torch.nn.Module):
|
class MedusaModel(torch.nn.Module):
|
||||||
@ -815,12 +784,17 @@ class SpeculativeHead(nn.Module):
|
|||||||
speculator_config = json.load(f)
|
speculator_config = json.load(f)
|
||||||
lm_head = None
|
lm_head = None
|
||||||
|
|
||||||
|
# currently medusa does not have an architecture specified, so try-except for now
|
||||||
|
# this should really be handled in a better way though (maybe the classname can be part of the config)
|
||||||
|
try:
|
||||||
architecture = speculator_config["architectures"][0]
|
architecture = speculator_config["architectures"][0]
|
||||||
|
|
||||||
if architecture == "MLPSpeculatorPreTrainedModel":
|
if architecture == "MLPSpeculatorPreTrainedModel":
|
||||||
speculator_config.use_speculator = config.use_speculator
|
speculator_config.use_speculator = config.use_speculator
|
||||||
speculator = MLPSpeculatorHeadV1.load(speculator_config, "speculator", weights)
|
speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
|
||||||
else: # not sure what medusa name is...
|
else:
|
||||||
|
speculator = None
|
||||||
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
speculator = MedusaHeadV1.load(config, prefix, weights)
|
speculator = MedusaHeadV1.load(config, prefix, weights)
|
||||||
except:
|
except:
|
||||||
|
Loading…
Reference in New Issue
Block a user