From a313355d2b32778a41fa91eefe7e8060373ed4fc Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 29 Aug 2024 17:44:54 +0200
Subject: [PATCH] Tied embeddings in MLP speculator. (#2473)

* Tied embeddings in MLP speculator.

* Fixing the scale_weight when users decide to not use the speculation as much as defined in the config.

* Adding scaling support + optimize some ops.
---
 server/text_generation_server/layers/mlp.py    | 122 +++++++++++++++++-
 .../text_generation_server/models/__init__.py  |   5 +
 2 files changed, 120 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py
index f08cb673..d33b41f3 100644
--- a/server/text_generation_server/layers/mlp.py
+++ b/server/text_generation_server/layers/mlp.py
@@ -45,12 +45,107 @@ class MLPSpeculatorLayerNorm(nn.Module):
         return x
 
 
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+    xf = x
+    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+    x = xf.type_as(x)
+    return x * INV_SQRT2
+
+
+class MLPSpeculatorModelTied(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.config = config
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
+
+        self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
+        self.proj0 = FastLinear.load(
+            config,
+            prefix=f"{prefix}.proj.0",
+            weights=weights,
+            bias=False,
+        )
+        self.proj1 = FastLinear.load(
+            config,
+            prefix=f"{prefix}.proj.1",
+            weights=weights,
+            bias=False,
+        )
+        self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
+        self.ln = MLPSpeculatorLayerNorm(
+            prefix=f"{prefix}.ln.0",
+            config=config,
+            weights=weights,
+        )
+
+        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+        self.activation = nn.GELU()
+        self.vsize = config.vocab_size
+        self.inner_dim = config.speculator_config["inner_dim"]
+        self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        self.emb.weight *= self.emb_weight
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+    ):
+        top_k_tokens_per_head = self.top_k_tokens_per_head
+
+        # k indicates # of candidates
+        # h indicates # of generated tokens
+        state = hidden_states
+        b = state.size(0)
+        ind = input_ids.unsqueeze(0)
+        all_probs = torch.empty(
+            b, self.n_predict, self.vsize, device=state.device
+        )  # b k h v
+        assert (
+            len(top_k_tokens_per_head) == self.n_predict
+        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+        for i in range(self.n_predict):
+            # Project and predict
+            z = self.emb(ind)
+            # z = z.mul(self.emb_weight)  # b k d
+            if i == 0:
+                state = self.proj0(state) * self.state_weight + z
+            else:
+                state = self.proj1(state) * self.state_weight + z
+            state = self.activation(self.ln(state))  # b k d
+            probs = F.log_softmax(self.head(state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+
+            # Update candidate set with new predictions
+
+            # Update distribution set with new logits
+            all_probs[:, i] = probs.exp()
+
+            # Update state, log_probs and ind for new predictions
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
+            state = state.reshape(-1, b, state.size(3))  # b kk' d
+            ind = preds.view(-1, b)  # b kk'
+
+        speculative_logits = all_probs
+        return speculative_logits
+
+
 class MLPSpeculatorModel(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
         self.n_predict = get_speculate()
         self.hidden_size = config.hidden_size
+
         self.emb = nn.ModuleList(
             [
                 TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
@@ -84,13 +179,17 @@ class MLPSpeculatorModel(torch.nn.Module):
         )
 
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
-        self.state_weight = 0.5 ** (0.5 / self.n_predict)
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
+        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
         self.activation = nn.GELU()
-        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
        self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        # self.emb is an nn.ModuleList here, so scale each embedding table in turn
+        for emb in self.emb:
+            emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -113,7 +212,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb[i](ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z = z.mul(self.emb_weight)  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
             probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
@@ -136,10 +235,11 @@ class MLPSpeculatorModel(torch.nn.Module):
 
 
 class MLPSpeculatorHead(nn.Module):
-    def __init__(self, lm_head, mlp_speculator):
+    def __init__(self, lm_head, mlp_speculator, scale_input: bool):
         super().__init__()
         self.lm_head = lm_head
         self.mlp_speculator = mlp_speculator
+        self.scale_input = scale_input
 
     def forward(
         self, input: torch.Tensor
@@ -150,6 +250,8 @@ class MLPSpeculatorHead(nn.Module):
             return logits, None
 
         input_ids = logits.argmax(dim=-1)
+        if self.scale_input:
+            input = simple_norm(input)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
 
@@ -171,6 +273,12 @@ class MLPSpeculatorHead(nn.Module):
             )
             routing[k] = filename
 
-        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        tie_weights = config.speculator_config.get("tie_weights", False)
+        if tie_weights:
+            mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
+        else:
+            mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+        scale_input = config.speculator_config.get("scale_input", False)
         lm_head = TensorParallelHead.load(config, prefix, weights)
-        return MLPSpeculatorHead(lm_head, mlp_speculator)
+        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e03cc30d..52f332c1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -458,6 +458,11 @@ def get_model(
                     revision=mlp_revision,
                     filename=filename,
                 )
+            speculator_dir_path = Path(mlp_speculator_config).parent
+            # if these are downloaded, they get converted to safetensors
+            filenames.extend(
+                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
+            )
             speculator = {
                 "path": Path(mlp_speculator_config).parent,
                 "model_paths": filenames,
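
A note on the scale_input path: simple_norm RMS-normalizes the incoming hidden state and multiplies by INV_SQRT2, so the expected squared magnitude entering the speculator is halved. A minimal standalone check of that property, with a hypothetical (batch, hidden) shape and no TGI dependencies:

    import torch

    INV_SQRT2 = 2**-0.5

    def simple_norm(x: torch.Tensor, eps: float = 1e-06):
        # RMS-normalize over the hidden dim, then scale by 1/sqrt(2)
        xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return xf.type_as(x) * INV_SQRT2

    x = torch.randn(4, 4096)      # hypothetical (batch, hidden) activations
    y = simple_norm(x)
    print(y.pow(2).mean(-1))      # ~0.5 per row: unit mean square, halved by 1/sqrt(2)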
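
The commented-out z = z.mul(self.emb_weight) lines are the "optimize some ops" part: the per-step multiply (which now also folds in the old math.sqrt(self.inner_dim / 2) factor) is baked into the embedding table once at load time. A sketch of why the two forms are equivalent, using a plain nn.Embedding as a stand-in for TensorParallelEmbedding and made-up dimensions:

    import math

    import torch
    import torch.nn as nn

    inner_dim, vocab, n_predict = 4096, 32000, 3  # hypothetical config values
    state_weight = 0.5 ** (0.5 / n_predict) if n_predict > 0 else 1
    emb_weight = math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2)

    emb = nn.Embedding(vocab, inner_dim)   # stand-in for TensorParallelEmbedding
    ind = torch.randint(0, vocab, (1, 8))

    z_per_step = emb(ind) * emb_weight     # before: scale z on every forward pass
    with torch.no_grad():                  # a plain nn.Parameter needs no_grad here
        emb.weight *= emb_weight           # after: scale the table once at load time
    z_folded = emb(ind)

    torch.testing.assert_close(z_per_step, z_folded)

In the untied MLPSpeculatorModel the embeddings live in an nn.ModuleList, which is why the load-time scaling there loops over the individual tables instead of touching a single .weight.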
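
The if self.n_predict > 0 else 1 guard is the scale_weight fix from the commit message: get_speculate() reflects how much speculation the user actually requested, which can be less than the config's n_predict, including zero. Without the guard, the exponent divides by zero:

    n_predict = 0  # user turned speculation off entirely
    # state_weight = 0.5 ** (0.5 / n_predict)  # ZeroDivisionError without the guard
    state_weight = 0.5 ** (0.5 / n_predict) if n_predict > 0 else 1
    print(state_weight)  # 1: the projected state passes through unscaled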