initial commit of mlp_speculator support (draft)

This commit is contained in:
Joshua Rosenkranz 2024-05-02 10:18:42 -04:00
parent 6073ece4fc
commit 43a2a0ca5e

View File

@ -1,4 +1,7 @@
import json
import os import os
from pathlib import Path
import torch import torch
import torch.distributed import torch.distributed
@ -439,6 +442,197 @@ class ResBlock(torch.nn.Module):
def forward(self, x): def forward(self, x):
return x + self.act(self.linear(x)) return x + self.act(self.linear(x))
class LayerNormParameterized(nn.Module):
"""
A generalized LayerNorm implementation. With all optional arguments set to True, equivalent to nn.LayerNorm up to epsilon stabilization term
(this class divides inputs by min(norm, eps), while nn.LayerNorm divides by norm + eps).
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
eps : float
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
elementwise_scale : bool
Include a learned scaling term after normalization?
elementwise_shift : bool
Include a learned bias term after normalization?
use_mean : bool
Recenter inputs around zero before normalizing, or just rescale?
"""
def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_scale=True,
elementwise_shift=False,
use_mean=False,
use_high_precision_pow=False,
):
super(LayerNormParameterized, self).__init__()
self.normalized_shape = normalized_shape
self.eps = eps
self.elementwise_scale = elementwise_scale
self.elementwise_shift = elementwise_shift
self.use_mean = use_mean
self.use_high_precision_pow = use_high_precision_pow
if self.elementwise_scale:
self.weight = nn.Parameter(torch.empty(self.normalized_shape))
if self.elementwise_shift:
self.bias = nn.Parameter(torch.empty(self.normalized_shape))
def reset_parameters(self):
if self.elementwise_scale:
self.weight.data.fill_(1)
if self.elementwise_shift:
self.bias.data.zero_()
def forward(self, x):
if self.use_mean:
x = x - x.mean(-1, keepdim=True)
xf = x
if self.use_high_precision_pow:
xf = x.float()
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
if self.elementwise_scale:
x = self.weight * x
if self.elementwise_shift:
x = x + self.bias
return x
class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, emb, proj, head, ln):
super().__init__()
self.config = config
self.n_predict = config.n_predict
self.emb_dim = config.emb_dim
inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
self.inner_dim = inner_dim
self.config = config.vocab_size
self.emb = emb
self.proj = proj
self.head = head
self.ln = ln
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict)
self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
self.activation = nn.GELU()
def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
if top_k_tokens_per_head is None:
top_k_tokens_per_head = self.config.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
b = state.size(0)
out = torch.empty(b, 1, 0, device=state.device).int() # b k h
log_probs = torch.zeros(b, 1, device=state.device) # b k
all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb[i](ind)
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
state = self.proj[i](state) * self.state_weight + z
state = self.activation(self.ln[i](state)) # b k d
_probs = F.log_softmax(self.head[i](state), dim=2) # b k v
probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2) # b k k'
# Update candidate set with new predictions
out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1) # b k k' h
out = torch.cat([out, preds.unsqueeze(3)], dim=3) # b k k' h+1
out = out.view(b, -1, i + 1) # b kk' h+1
# Update distribution set with new logits
all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2) # b k h+1 v
all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1) # b kk' h+1 v
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1) # b k k' d
state = state.reshape(b, -1, state.size(3)) # b kk' d
ind = preds.view(b, -1) # b kk'
log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i]) # b k k'
log_probs = log_probs.add(probs).reshape(b, -1) # b kk'
# Take only top n best guesses
best_guesses = log_probs.topk(num_candidates, dim=1)[1] # b k
return all_probs.gather(
1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
) # b n h v
def load(self, config, prefix, weights):
self.emb = nn.ModuleList(
[nn.Embedding(config.vocab_size, config.inner_dim) for _ in range(config.n_predict)]
)
self.proj = nn.ModuleList(
[
nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False)
for i in range(config.n_predict)
]
)
self.head = nn.ModuleList(
[nn.Linear(config.inner_dim, config.vocab_size, bias=False) for _ in range(config.n_predict)]
)
self.ln = nn.ModuleList(
[
LayerNormParameterized(
config.inner_dim, elementwise_shift=True, elementwise_scale=True
)
for _ in range(config.n_predict)
]
)
for i in range(config.n_predict):
self.emb[i].weight.data.copy_(weights.get_tensor(f"{prefix}.emb.{i}.weight"))
self.proj[i].weight.data.copy_(weights.get_tensor(f"{prefix}.proj.{i}.weight"))
self.ln[i].weight.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.weight"))
self.ln[i].bias.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.bias"))
self.head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight"))
class MLPSpeculatorHeadV1(nn.Module):
def __init__(self, lm_head, mlp_speculator):
super().__init__()
self.lm_head = lm_head
self.mlp_speculator = mlp_speculator
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
if input.shape[0] > 128:
return logits, None
speculative_logits = self.mlp_speculator(input)
return logits, speculative_logits
@staticmethod
def load(speculator_config, prefix: str, weights):
from pathlib import Path
from safetensors import safe_open
speculator_path = speculator_config.use_speculator
filename = str(Path(speculator_path) / "*.safetensors")
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
mlp_speculator = MLPSpeculatorModel.load(speculator_config, prefix, weights)
lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
return MLPSpeculatorHeadV1(lm_head, mlp_speculator)
class MedusaModel(torch.nn.Module): class MedusaModel(torch.nn.Module):
def __init__(self, config, medusa_config, weights): def __init__(self, config, medusa_config, weights):
@ -606,24 +800,35 @@ class MedusaHeadV2(nn.Module):
class SpeculativeHead(nn.Module): class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa): def __init__(self, lm_head, speculator):
super().__init__() super().__init__()
self.head = lm_head self.head = lm_head
self.medusa = medusa self.speculator = speculator
@staticmethod @staticmethod
def load(config, prefix: str, weights): def load(config, prefix: str, weights):
use_medusa = config.use_medusa use_speculator = config.use_speculator
if use_medusa: if use_speculator:
speculator_config = str(Path(use_speculator) / "config.json")
with open(speculator_config, "r") as f:
speculator_config = json.load(f)
lm_head = None lm_head = None
try:
medusa = MedusaHeadV1.load(config, prefix, weights) architecture = speculator_config["architectures"][0]
except:
medusa = MedusaHeadV2(config, prefix, weights) if architecture == "MLPSpeculatorPreTrainedModel":
speculator_config.use_speculator = config.use_speculator
speculator = MLPSpeculatorHeadV1.load(speculator_config, "speculator", weights)
else: # not sure what medusa name is...
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
except:
speculator = MedusaHeadV2(config, prefix, weights)
else: else:
lm_head = TensorParallelHead.load(config, prefix, weights) lm_head = TensorParallelHead.load(config, prefix, weights)
medusa = None speculator = None
return SpeculativeHead(lm_head, medusa) return SpeculativeHead(lm_head, speculator)
def forward( def forward(
self, input: torch.Tensor self, input: torch.Tensor