diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 6e4a13cd..0fccbd75 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,4 +1,9 @@
+import json
+import math
 import os
+from pathlib import Path
+from types import SimpleNamespace
+
 import torch
 import torch.distributed
 
@@ -439,6 +444,197 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
 
+
+class LayerNormParameterized(nn.Module):
+    """
+    A generalized LayerNorm implementation. With all optional arguments set to True,
+    equivalent to nn.LayerNorm up to epsilon stabilization term (this class divides
+    inputs by min(norm, eps), while nn.LayerNorm divides by norm + eps).
+
+    Args
+    ----
+    normalized_shape : int
+        Dimensionality of input data (size of final tensor axis).
+    eps : float
+        Safety term to prevent division by zero. Make sure the chosen value fits in
+        the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
+    elementwise_scale : bool
+        Include a learned scaling term after normalization?
+    elementwise_shift : bool
+        Include a learned bias term after normalization?
+    use_mean : bool
+        Recenter inputs around zero before normalizing, or just rescale?
+    """
+
+    def __init__(
+        self,
+        normalized_shape,
+        eps=1e-06,
+        elementwise_scale=True,
+        elementwise_shift=False,
+        use_mean=False,
+        use_high_precision_pow=False,
+    ):
+        super(LayerNormParameterized, self).__init__()
+        self.normalized_shape = normalized_shape
+        self.eps = eps
+        self.elementwise_scale = elementwise_scale
+        self.elementwise_shift = elementwise_shift
+        self.use_mean = use_mean
+        self.use_high_precision_pow = use_high_precision_pow
+
+        if self.elementwise_scale:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
+        if self.elementwise_shift:
+            self.bias = nn.Parameter(torch.empty(self.normalized_shape))
+
+    def reset_parameters(self):
+        if self.elementwise_scale:
+            self.weight.data.fill_(1)
+        if self.elementwise_shift:
+            self.bias.data.zero_()
+
+    def forward(self, x):
+        if self.use_mean:
+            x = x - x.mean(-1, keepdim=True)
+        xf = x
+        if self.use_high_precision_pow:
+            xf = x.float()
+        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
+        x = xf.type_as(x)
+        if self.elementwise_scale:
+            x = self.weight * x
+        if self.elementwise_shift:
+            x = x + self.bias
+        return x
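+
+# Reviewer-added orientation comment (illustrative shapes and values only,
+# not a runtime code path). With the defaults (use_mean=False, shift off) this
+# is an RMSNorm-style layer; with use_mean=True and both affine terms on it
+# matches nn.LayerNorm up to the eps placement noted in the docstring:
+#
+#     ln = LayerNormParameterized(4096, elementwise_shift=True, use_mean=True)
+#     ln.reset_parameters()              # weights start from torch.empty
+#     y = ln(torch.randn(2, 16, 4096))   # ~ nn.LayerNorm(4096)
+#
+#     rms = LayerNormParameterized(4096) # defaults: scale only, no recentering
+#     rms.reset_parameters()
+#     y = rms(torch.randn(2, 16, 4096))  # x * rsqrt(mean(x**2) + eps) * weight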
+
+
+class MLPSpeculatorModel(torch.nn.Module):
+    def __init__(self, config, emb, proj, head, ln):
+        super().__init__()
+        self.config = config
+        self.n_predict = config.n_predict
+        self.emb_dim = config.emb_dim
+        inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
+        self.inner_dim = inner_dim
+        self.vsize = config.vocab_size
+        self.emb = emb
+        self.proj = proj
+        self.head = head
+        self.ln = ln
+        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+        self.state_weight = 0.5 ** (0.5 / self.n_predict)
+        self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
+        self.activation = nn.GELU()
+
+    def forward(
+        self,
+        state: torch.Tensor,
+        ind: torch.Tensor,
+        top_k_tokens_per_head: Optional[List[int]] = None,
+        num_candidates: int = 1,
+    ):
+        if top_k_tokens_per_head is None:
+            top_k_tokens_per_head = self.config.top_k_tokens_per_head
+
+        # k indicates # of candidates
+        # h indicates # of generated tokens
+        b = state.size(0)
+        out = torch.empty(b, 1, 0, device=state.device).int()  # b k h
+        log_probs = torch.zeros(b, 1, device=state.device)  # b k
+        all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device)  # b k h v
+
+        assert (
+            len(top_k_tokens_per_head) == self.n_predict
+        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+
+        for i in range(self.n_predict):
+            # Project and predict
+            z = self.emb[i](ind)
+            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            state = self.proj[i](state) * self.state_weight + z
+            state = self.activation(self.ln[i](state))  # b k d
+            _probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
+            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2)  # b k k'
+
+            # Update candidate set with new predictions
+            out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' h
+            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
+            out = out.view(b, -1, i + 1)  # b kk' h+1
+
+            # Update distribution set with new logits
+            all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2)  # b k h+1 v
+            # repeat_interleave keeps the candidate ordering consistent with the
+            # expand/reshape flattening used for out, state and log_probs
+            all_probs = all_probs.repeat_interleave(top_k_tokens_per_head[i], dim=1)  # b kk' h+1 v
+
+            # Update state, log_probs and ind for new predictions
+            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' d
+            state = state.reshape(b, -1, state.size(3))  # b kk' d
+            ind = preds.view(b, -1)  # b kk'
+            log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i])  # b k k'
+            log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'
+
+        # Take only the top n best guesses
+        best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b n
+        return all_probs.gather(
+            1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
+        )  # b n h v
+
+    @staticmethod
+    def load(config, prefix, weights):
+        emb = nn.ModuleList(
+            [nn.Embedding(config.vocab_size, config.inner_dim) for _ in range(config.n_predict)]
+        )
+        proj = nn.ModuleList(
+            [
+                nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False)
+                for i in range(config.n_predict)
+            ]
+        )
+        head = nn.ModuleList(
+            [nn.Linear(config.inner_dim, config.vocab_size, bias=False) for _ in range(config.n_predict)]
+        )
+        ln = nn.ModuleList(
+            [
+                LayerNormParameterized(
+                    config.inner_dim, elementwise_shift=True, elementwise_scale=True
+                )
+                for _ in range(config.n_predict)
+            ]
+        )
+        for i in range(config.n_predict):
+            emb[i].weight.data.copy_(weights.get_tensor(f"{prefix}.emb.{i}.weight"))
+            proj[i].weight.data.copy_(weights.get_tensor(f"{prefix}.proj.{i}.weight"))
+            ln[i].weight.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.weight"))
+            ln[i].bias.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.bias"))
+            head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight"))
+        return MLPSpeculatorModel(config, emb, proj, head, ln)
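+
+# Reviewer-added note on the search above (illustrative numbers only): each
+# head multiplies the candidate count by its top-k, so the suffix set grows as
+#
+#     candidates = 1
+#     for k in top_k_tokens_per_head:    # e.g. [5, 3, 2]
+#         candidates *= k                # 5, then 15, then 30
+#
+# and forward() finally keeps only `num_candidates` of those suffixes via
+# log_probs.topk, returning a (b, num_candidates, n_predict, vocab) tensor.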
+
+
+class MLPSpeculatorHeadV1(nn.Module):
+    def __init__(self, lm_head, mlp_speculator):
+        super().__init__()
+        self.lm_head = lm_head
+        self.mlp_speculator = mlp_speculator
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        # If we have too many tokens, we skip speculative logits
+        if input.shape[0] > 128:
+            return logits, None
+
+        # Seed the speculator with one candidate per token: the hidden state
+        # and the greedy next token from the base LM head
+        speculative_logits = self.mlp_speculator(
+            input.unsqueeze(1), logits.argmax(-1, keepdim=True)
+        )
+        return logits, speculative_logits
+
+    @staticmethod
+    def load(speculator_config, prefix: str, weights):
+        from safetensors import safe_open
+
+        speculator_path = speculator_config.use_speculator
+
+        # Register every safetensors file in the speculator directory so that
+        # `weights` can route speculator keys to the right file
+        for filename in sorted(Path(speculator_path).glob("*.safetensors")):
+            filename = str(filename)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
+
+        mlp_speculator = MLPSpeculatorModel.load(speculator_config, "speculator", weights)
+        lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
+        return MLPSpeculatorHeadV1(lm_head, mlp_speculator)
+
 
 class MedusaModel(torch.nn.Module):
     def __init__(self, config, medusa_config, weights):
@@ -606,24 +800,35 @@ class MedusaHeadV2(nn.Module):
 
 class SpeculativeHead(nn.Module):
-    def __init__(self, lm_head, medusa):
+    def __init__(self, lm_head, speculator):
         super().__init__()
         self.head = lm_head
-        self.medusa = medusa
+        self.speculator = speculator
 
     @staticmethod
     def load(config, prefix: str, weights):
-        use_medusa = config.use_medusa
-        if use_medusa:
+        use_speculator = config.use_speculator
+        if use_speculator:
+            speculator_config_path = str(Path(use_speculator) / "config.json")
+
+            with open(speculator_config_path, "r") as f:
+                speculator_config = json.load(f)
             lm_head = None
-            try:
-                medusa = MedusaHeadV1.load(config, prefix, weights)
-            except:
-                medusa = MedusaHeadV2(config, prefix, weights)
+
+            architecture = speculator_config["architectures"][0]
+
+            if architecture == "MLPSpeculatorPreTrainedModel":
+                # json.load returns a plain dict; wrap it so downstream code
+                # can keep using attribute access
+                speculator_config = SimpleNamespace(**speculator_config)
+                speculator_config.use_speculator = config.use_speculator
+                speculator = MLPSpeculatorHeadV1.load(speculator_config, prefix, weights)
+            else:  # assume Medusa; its architecture name is still to be confirmed
+                try:
+                    speculator = MedusaHeadV1.load(config, prefix, weights)
+                except Exception:
+                    speculator = MedusaHeadV2(config, prefix, weights)
         else:
             lm_head = TensorParallelHead.load(config, prefix, weights)
-            medusa = None
-        return SpeculativeHead(lm_head, medusa)
+            speculator = None
+        return SpeculativeHead(lm_head, speculator)
 
     def forward(
         self, input: torch.Tensor
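
Note on the loading flow above (annotation, not part of the patch): a speculator directory is expected to look like <use_speculator>/config.json plus one or more *.safetensors files, and SpeculativeHead.load routes on the config's "architectures" entry. A minimal sketch of that decision, with a hypothetical path:

    import json
    from pathlib import Path

    speculator_dir = Path("/models/my-speculator")   # hypothetical location
    with open(speculator_dir / "config.json") as f:
        architecture = json.load(f)["architectures"][0]

    if architecture == "MLPSpeculatorPreTrainedModel":
        ...  # MLPSpeculatorHeadV1.load
    else:
        ...  # MedusaHeadV1 / MedusaHeadV2 fallback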