Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Adding scaling support + optimize some ops.
Commit 9f036684ef (parent 09a1de5cd1)
@@ -45,6 +45,16 @@ class MLPSpeculatorLayerNorm(nn.Module):
         return x
 
 
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+    xf = x
+    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+    x = xf.type_as(x)
+    return x * INV_SQRT2
+
+
 class MLPSpeculatorModelTied(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
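Note on the new helper: simple_norm is an RMS-style normalization with no learned scale, followed by a fixed 1/sqrt(2) factor. A minimal, self-contained sketch of its behavior (the (batch, hidden) shape below is only for illustration):

import torch

INV_SQRT2 = 2**-0.5

def simple_norm(x: torch.Tensor, eps=1e-06):
    # Divide by the root-mean-square over the last dimension, then scale by 1/sqrt(2).
    xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return xf.type_as(x) * INV_SQRT2

hidden = torch.randn(2, 4096)      # hypothetical (batch, hidden) activations
out = simple_norm(hidden)
print(out.pow(2).mean(-1))         # roughly 0.5 per row: unit RMS scaled down by 1/sqrt(2)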
@@ -74,12 +84,14 @@ class MLPSpeculatorModelTied(torch.nn.Module):
 
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
-        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        self.emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -102,7 +114,7 @@ class MLPSpeculatorModelTied(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb(ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z = z.mul(self.emb_weight)  # b k d
             if i == 0:
                 state = self.proj0(state) * self.state_weight + z
             else:
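The "optimize some ops" part of the commit: the constant sqrt(1 - state_weight^2) * sqrt(inner_dim / 2) factor that was previously applied to z on every decode step is now folded into the embedding weights once at load time, so the per-step multiply above can be dropped. A minimal sketch of the equivalence, using a plain nn.Embedding as a stand-in for the model's TensorParallelEmbedding and made-up sizes (assumptions for illustration only):

import math
import torch
import torch.nn as nn

inner_dim, vocab, n_predict = 4096, 128, 4      # hypothetical sizes
state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2)

emb = nn.Embedding(vocab, inner_dim)
ind = torch.randint(0, vocab, (2, 3))

# Old path: look up, then scale the activations on every forward pass.
z_old = emb(ind).mul(math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2))

# New path: scale the weights once at init; the lookup alone is then enough.
with torch.no_grad():
    emb.weight *= emb_weight
z_new = emb(ind)

torch.testing.assert_close(z_old, z_new)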
@@ -168,12 +180,14 @@ class MLPSpeculatorModel(torch.nn.Module):
 
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
-        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        self.emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -196,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb[i](ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z = z.mul(self.emb_weight)  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
             probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
@@ -219,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module):
 
 
 class MLPSpeculatorHead(nn.Module):
-    def __init__(self, lm_head, mlp_speculator):
+    def __init__(self, lm_head, mlp_speculator, scale_input: bool):
         super().__init__()
         self.lm_head = lm_head
         self.mlp_speculator = mlp_speculator
+        self.scale_input = scale_input
 
     def forward(
         self, input: torch.Tensor
@@ -233,6 +248,8 @@ class MLPSpeculatorHead(nn.Module):
             return logits, None
 
         input_ids = logits.argmax(dim=-1)
+        if self.scale_input:
+            input = simple_norm(input)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
 
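With scale_input enabled, the last hidden state is passed through simple_norm before it reaches the speculator; with the flag off, the path is unchanged. A minimal sketch of the gate, with a condensed simple_norm and a stub callable standing in for MLPSpeculatorModel (assumptions for illustration only):

import torch

def simple_norm(x, eps=1e-06):          # condensed from the helper added in the first hunk
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * 2**-0.5

def speculative_forward(input, logits, speculator, scale_input):
    # Greedy tokens from the main LM head, plus the (optionally normalized)
    # hidden states, go to the speculator; scale_input=False keeps the old path.
    input_ids = logits.argmax(dim=-1)
    if scale_input:
        input = simple_norm(input)
    return speculator(input, input_ids)

# Stub speculator, just to exercise the gate.
speculator = lambda hidden, ids: (hidden.mean().item(), ids.shape)
hidden = torch.randn(4, 4096)
logits = torch.randn(4, 32000)
print(speculative_forward(hidden, logits, speculator, scale_input=True))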
@@ -259,5 +276,7 @@ class MLPSpeculatorHead(nn.Module):
             mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
         else:
             mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+        scale_input = config.speculator_config.get("scale_input", False)
         lm_head = TensorParallelHead.load(config, prefix, weights)
-        return MLPSpeculatorHead(lm_head, mlp_speculator)
+        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
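The new flag is read from the speculator's config (the comment in the hunk above points at the ibm-fms/llama3-70b-accelerator checkpoint that uses it) and defaults to False, so existing speculators keep their current behavior. A hand-written dict below stands in for config.speculator_config; the keys other than "inner_dim" and "scale_input" are illustrative assumptions:

# Hypothetical speculator_config as it might appear in a speculator checkpoint.
speculator_config = {
    "inner_dim": 4096,
    "scale_input": True,      # opt-in; an absent key falls back to False below
}

scale_input = speculator_config.get("scale_input", False)
print(scale_input)            # True here; False for older configs without the key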