Adding scaling support + optimize some ops.

Nicolas Patry 2024-08-29 17:31:41 +02:00
parent 09a1de5cd1
commit 9f036684ef
GPG Key ID: 64AF4752B2967863


@@ -45,6 +45,16 @@ class MLPSpeculatorLayerNorm(nn.Module):
         return x
 
 
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+    xf = x
+    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+    x = xf.type_as(x)
+    return x * INV_SQRT2
+
+
 class MLPSpeculatorModelTied(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
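For context on the function added above: simple_norm is a weight-free RMS norm scaled by INV_SQRT2, so each row of hidden states ends up with a squared RMS of about 0.5 regardless of its original scale. A minimal, self-contained sketch (duplicating the new function, with an invented input shape) that checks this numerically:

    import torch

    INV_SQRT2 = 2**-0.5

    def simple_norm(x: torch.Tensor, eps=1e-06):
        # RMS-normalize the last dimension, then scale by 1/sqrt(2).
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        x = xf.type_as(x)
        return x * INV_SQRT2

    x = torch.randn(4, 4096) * 3.7       # arbitrary input scale
    y = simple_norm(x)
    print(y.pow(2).mean(-1))             # ~0.5 per row, independent of the 3.7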
@@ -74,12 +84,14 @@ class MLPSpeculatorModelTied(torch.nn.Module):
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
+        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        self.emb.weight *= self.emb_weight
 
     def forward(
         self,
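The "optimize some ops" part: the constant embedding scale (emb_weight, now including the sqrt(inner_dim / 2) factor that used to be applied in forward) is folded into the embedding weights once at load time, removing a per-token multiply from every speculation step. A hypothetical sketch of why the two are equivalent, with invented sizes:

    import math
    import torch
    from torch import nn

    inner_dim, vocab, n_predict = 64, 100, 4   # hypothetical sizes
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2)

    emb = nn.Embedding(vocab, inner_dim)
    ind = torch.randint(vocab, (2, 3))

    with torch.no_grad():
        z_runtime = emb(ind).mul(emb_weight)   # old: scale on every forward pass
        emb.weight *= emb_weight               # new: scale once, at load time
        z_folded = emb(ind)                    # lookup now returns pre-scaled rows

    print(torch.allclose(z_runtime, z_folded))  # True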
@@ -102,7 +114,7 @@ class MLPSpeculatorModelTied(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb(ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z = z.mul(self.emb_weight)  # b k d
             if i == 0:
                 state = self.proj0(state) * self.state_weight + z
             else:
@@ -168,12 +180,14 @@ class MLPSpeculatorModel(torch.nn.Module):
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
+        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        self.emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -196,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb[i](ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z = z.mul(self.emb_weight)  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
             probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
@@ -219,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module):
 class MLPSpeculatorHead(nn.Module):
-    def __init__(self, lm_head, mlp_speculator):
+    def __init__(self, lm_head, mlp_speculator, scale_input: bool):
         super().__init__()
         self.lm_head = lm_head
         self.mlp_speculator = mlp_speculator
+        self.scale_input = scale_input
 
     def forward(
         self, input: torch.Tensor
@@ -233,6 +248,8 @@ class MLPSpeculatorHead(nn.Module):
             return logits, None
 
         input_ids = logits.argmax(dim=-1)
+        if self.scale_input:
+            input = simple_norm(input)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
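The flag then gates the new normalization between the LM head and the speculator. A trimmed, hypothetical sketch of the resulting control flow (lm_head and mlp_speculator stand in for the real modules, and simple_norm is the function defined earlier):

    import torch

    def speculator_forward(lm_head, mlp_speculator, input: torch.Tensor,
                           scale_input: bool):
        # Simplified version of MLPSpeculatorHead.forward after this change.
        logits = lm_head(input)
        input_ids = logits.argmax(dim=-1)
        if scale_input:
            input = simple_norm(input)  # weight-free RMS norm from above
        speculative_logits = mlp_speculator(input, input_ids)
        return logits, speculative_logits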
@@ -259,5 +276,7 @@ class MLPSpeculatorHead(nn.Module):
             mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
         else:
             mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+        scale_input = config.speculator_config.get("scale_input", False)
         lm_head = TensorParallelHead.load(config, prefix, weights)
-        return MLPSpeculatorHead(lm_head, mlp_speculator)
+        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
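Enabling the path is driven entirely by the checkpoint's speculator_config, as in https://huggingface.co/ibm-fms/llama3-70b-accelerator; the flag defaults to False, so existing speculators keep their old behaviour. A hypothetical excerpt (the inner_dim value is invented; scale_input is read exactly as above):

    speculator_config = {
        "inner_dim": 4096,       # hypothetical value
        "scale_input": True,     # opt in to simple_norm on the hidden states
    }
    scale_input = speculator_config.get("scale_input", False)  # -> True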