added a bunch of cleanup based on comments in the PR; removed conditionals from LayerNormParameterized and renamed it to MLPSpeculatorLayerNorm; now using modules for tensor-parallel (work in progress - looking into whether this is the right approach); fixed an issue with getting the medusa model; reworked loading to be more efficient

Joshua Rosenkranz 2024-05-03 10:02:11 -04:00
parent 43a2a0ca5e
commit b9d8af694b


@@ -442,69 +442,46 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
 
 
-class LayerNormParameterized(nn.Module):
+class MLPSpeculatorLayerNorm(nn.Module):
     """
-    A generalized LayerNorm implementation. With all optional arguments set to True, equivalent to nn.LayerNorm up to epsilon stabilization term
-    (this class divides inputs by min(norm, eps), while nn.LayerNorm divides by norm + eps).
+    A L2 normalization implementation
     ...
     Args
     ----
     normalized_shape : int
         Dimensionality of input data (size of final tensor axis)
+    elementwise_scale_weight : torch.Tensor
+        learned scaling term after normalization?
+    elementwise_shift_bias : torch.Tensor
+        learned bias term after normalization?
     eps : float
         Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
-    elementwise_scale : bool
-        Include a learned scaling term after normalization?
-    elementwise_shift : bool
-        Include a learned bias term after normalization?
-    use_mean : bool
-        Recenter inputs around zero before normalizing, or just rescale?
     """
 
     def __init__(
         self,
         normalized_shape,
+        elementwise_scale_weight: torch.Tensor,
+        elementwise_shift_bias: torch.Tensor,
         eps=1e-06,
-        elementwise_scale=True,
-        elementwise_shift=False,
-        use_mean=False,
-        use_high_precision_pow=False,
     ):
-        super(LayerNormParameterized, self).__init__()
+        super(MLPSpeculatorLayerNorm, self).__init__()
         self.normalized_shape = normalized_shape
+        self.weight = nn.Parameter(elementwise_scale_weight)
+        self.bias = nn.Parameter(elementwise_shift_bias)
         self.eps = eps
-        self.elementwise_scale = elementwise_scale
-        self.elementwise_shift = elementwise_shift
-        self.use_mean = use_mean
-        self.use_high_precision_pow = use_high_precision_pow
-        if self.elementwise_scale:
-            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
-        if self.elementwise_shift:
-            self.bias = nn.Parameter(torch.empty(self.normalized_shape))
-
-    def reset_parameters(self):
-        if self.elementwise_scale:
-            self.weight.data.fill_(1)
-        if self.elementwise_shift:
-            self.bias.data.zero_()
 
     def forward(self, x):
-        if self.use_mean:
-            x = x - x.mean(-1, keepdim=True)
         xf = x
-        if self.use_high_precision_pow:
-            xf = x.float()
         xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
         x = xf.type_as(x)
-        if self.elementwise_scale:
-            x = self.weight * x
-        if self.elementwise_shift:
-            x = x + self.bias
+        x = self.weight * x
+        x = x + self.bias
         return x
 
 
 class MLPSpeculatorModel(torch.nn.Module):
-    def __init__(self, config, emb, proj, head, ln):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
         self.n_predict = config.n_predict
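With the conditionals gone, the renamed MLPSpeculatorLayerNorm always applies an elementwise scale and shift, and it now takes those parameters as tensors at construction time instead of allocating empty ones and resetting them later. A minimal usage sketch (illustrative only; the dimension and the placeholder tensors stand in for whatever weights.get_tensor returns for a real checkpoint):

    import torch

    dim = 4096                      # hypothetical inner_dim
    weight = torch.ones(dim)        # e.g. weights.get_tensor(f"{prefix}.ln.{i}.weight")
    bias = torch.zeros(dim)         # e.g. weights.get_tensor(f"{prefix}.ln.{i}.bias")

    # MLPSpeculatorLayerNorm is the class defined in the hunk above
    ln = MLPSpeculatorLayerNorm(dim, elementwise_scale_weight=weight, elementwise_shift_bias=bias)
    x = torch.randn(2, dim)
    y = ln(x)                       # L2-normalizes the last axis, then applies weight and bias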
@@ -512,10 +489,30 @@ class MLPSpeculatorModel(torch.nn.Module):
         inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
         self.inner_dim = inner_dim
         self.config = config.vocab_size
-        self.emb = emb
-        self.proj = proj
-        self.head = head
-        self.ln = ln
+        self.emb = nn.ModuleList(
+            [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
+        )
+        self.proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
+            weights=weights,
+            bias=False,
+            dim=0
+        )
+        self.head = nn.ModuleList(
+            [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
+        )
+        self.ln = nn.ModuleList(
+            [
+                MLPSpeculatorLayerNorm(
+                    config.inner_dim,
+                    elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
+                    elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
+                )
+                for i in range(config.n_predict)
+            ]
+        )
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict)
         self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
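The state_weight / emb_weight comment can be checked directly: the residual state is rescaled by state_weight at each of the n_predict heads, so its cumulative scale is (0.5 ** (0.5 / n_predict)) ** n_predict = sqrt(0.5), i.e. 50% of the squared magnitude by the final head, and emb_weight is chosen so the two mixing coefficients have unit sum of squares. A small sanity check (not part of the diff):

    import math

    n_predict = 3
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight ** 2)

    # cumulative squared scale of the state after n_predict heads is 0.5 (i.e. 50%)
    assert abs(state_weight ** (2 * n_predict) - 0.5) < 1e-12
    # each mixing step keeps the squared coefficients summing to 1
    assert abs(state_weight ** 2 + emb_weight ** 2 - 1.0) < 1e-12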
@@ -565,36 +562,8 @@ class MLPSpeculatorModel(torch.nn.Module):
             1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
         ) # b n h v
 
-    def load(self, config, prefix, weights):
-        self.emb = nn.ModuleList(
-            [nn.Embedding(config.vocab_size, config.inner_dim) for _ in range(config.n_predict)]
-        )
-        self.proj = nn.ModuleList(
-            [
-                nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False)
-                for i in range(config.n_predict)
-            ]
-        )
-        self.head = nn.ModuleList(
-            [nn.Linear(config.inner_dim, config.vocab_size, bias=False) for _ in range(config.n_predict)]
-        )
-        self.ln = nn.ModuleList(
-            [
-                LayerNormParameterized(
-                    config.inner_dim, elementwise_shift=True, elementwise_scale=True
-                )
-                for _ in range(config.n_predict)
-            ]
-        )
-        for i in range(config.n_predict):
-            self.emb[i].weight.data.copy_(weights.get_tensor(f"{prefix}.emb.{i}.weight"))
-            self.proj[i].weight.data.copy_(weights.get_tensor(f"{prefix}.proj.{i}.weight"))
-            self.ln[i].weight.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.weight"))
-            self.ln[i].bias.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.bias"))
-            self.head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight"))
-
-class MLPSpeculatorHeadV1(nn.Module):
+
+class MLPSpeculatorHead(nn.Module):
     def __init__(self, lm_head, mlp_speculator):
         super().__init__()
         self.lm_head = lm_head
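The deleted load method is the old pattern of building randomly initialized nn.Embedding / nn.Linear modules and then copy_-ing checkpoint tensors into them; the new __init__ in the earlier hunk builds the tensor-parallel modules straight from the weights loader, which is presumably the "more efficient loading" from the commit message, since no throwaway parameters are materialized. A rough sketch of the two patterns (assumes a weights object with get_tensor and the TensorParallelEmbedding loader used above; prefixes are placeholders):

    import torch.nn as nn

    def old_style_emb(config, prefix, weights):
        # allocate (and randomly initialize) the module, then overwrite it in place
        emb = nn.Embedding(config.vocab_size, config.inner_dim)
        emb.weight.data.copy_(weights.get_tensor(f"{prefix}.emb.0.weight"))
        return emb

    def new_style_emb(config, prefix, weights):
        # construct directly from the checkpoint, sharded for tensor parallelism
        return TensorParallelEmbedding(f"{prefix}.emb.0", weights)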
@@ -629,9 +598,9 @@ class MLPSpeculatorHeadV1(nn.Module):
                    )
                routing[k] = filename
 
-        mlp_speculator = MLPSpeculatorModel.load(speculator_config, prefix, weights)
+        mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights)
         lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
-        return MLPSpeculatorHeadV1(lm_head, mlp_speculator)
+        return MLPSpeculatorHead(lm_head, mlp_speculator)
 
 
 class MedusaModel(torch.nn.Module):
@@ -815,12 +784,17 @@ class SpeculativeHead(nn.Module):
                 speculator_config = json.load(f)
 
             lm_head = None
-            architecture = speculator_config["architectures"][0]
-            if architecture == "MLPSpeculatorPreTrainedModel":
-                speculator_config.use_speculator = config.use_speculator
-                speculator = MLPSpeculatorHeadV1.load(speculator_config, "speculator", weights)
-            else: # not sure what medusa name is...
+            # currently medusa does not have an architecture specified, so try-except for now
+            # this should really be handled in a better way though (maybe the classname can be part of the config)
+            try:
+                architecture = speculator_config["architectures"][0]
+                if architecture == "MLPSpeculatorPreTrainedModel":
+                    speculator_config.use_speculator = config.use_speculator
+                    speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
+                else:
+                    speculator = None
+            except KeyError:
                 try:
                     speculator = MedusaHeadV1.load(config, prefix, weights)
                 except:
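For reference, the two speculator config shapes the new try/except distinguishes: an MLP speculator config declares its architecture, while (per the comment in the hunk) a medusa config currently does not, so the KeyError on "architectures" is what routes loading to MedusaHeadV1. A toy illustration (the config dicts below are made up; only the presence or absence of the "architectures" key matters):

    mlp_speculator_config = {"architectures": ["MLPSpeculatorPreTrainedModel"], "n_predict": 3}
    medusa_config = {"medusa_num_heads": 4}  # hypothetical fields; no "architectures" key

    for cfg in (mlp_speculator_config, medusa_config):
        try:
            architecture = cfg["architectures"][0]
            print("MLP speculator path:", architecture)
        except KeyError:
            print("medusa path (no architecture in config)")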