From b8848990868cc66f921edc9b82597957c9a2605c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 8 May 2024 12:20:00 +0000
Subject: [PATCH] Removed a bunch of hardcodes.

---
 .../text_generation_server/models/__init__.py | 48 ++++++++++++-----
 server/text_generation_server/utils/layers.py | 54 ++++++-------
 2 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 05adc18e..fcecf8af 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,9 +1,10 @@
 import torch
+import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
 
@@ -166,9 +167,15 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            speculator = Path(medusa_config).parent
+            speculator = {
+                "path": Path(medusa_config).parent,
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
         else:
-            speculator = Path(medusa_model_id)
+            speculator = {
+                "path": Path(medusa_model_id),
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
 
         method = "medusa"
     elif config_dict["model_type"] == "mlp_speculator":
@@ -192,23 +199,36 @@ def get_model(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         is_local = Path(mlp_model_id).exists()
+        extension = ".safetensors"
         if not is_local:
             mlp_speculator_config = hf_hub_download(
                 mlp_model_id, revision=mlp_revision, filename="config.json"
             )
-            hf_hub_download(
-                mlp_model_id,
-                revision=mlp_revision,
-                filename="model-00001-of-00002.safetensors",
-            )
-            hf_hub_download(
-                mlp_model_id,
-                revision=mlp_revision,
-                filename="model-00002-of-00002.safetensors",
-            )
-            speculator = Path(mlp_speculator_config).parent
+            api = HfApi()
+            info = api.model_info(mlp_model_id, revision=mlp_revision)
+            filenames = [
+                s.rfilename
+                for s in info.siblings
+                if s.rfilename.endswith(extension)
+                and len(s.rfilename.split("/")) == 1
+                and "arguments" not in s.rfilename
+                and "args" not in s.rfilename
+                and "training" not in s.rfilename
+            ]
+            for filename in filenames:
+                hf_hub_download(
+                    mlp_model_id,
+                    revision=mlp_revision,
+                    filename=filename,
+                )
+            speculator = {
+                "path": Path(mlp_speculator_config).parent,
+                "model_paths": filenames,
+            }
         else:
             speculator = Path(mlp_model_id)
+            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+            speculator = {"path": speculator, "model_paths": filenames}
 
         method = "mlp_speculator"
     else:
         method = "n-gram"
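Note (illustrative sketch, not part of the diff): the mlp_speculator hunk above stops hardcoding the two shard names and instead discovers every top-level *.safetensors file through the Hub API before downloading it. Roughly the same logic as a standalone helper, with hypothetical function and variable names; only the filter conditions and the two huggingface_hub calls come from the patch:

    from huggingface_hub import HfApi, hf_hub_download

    def download_speculator_shards(model_id, revision=None, extension=".safetensors"):
        # List the repo once, keep only top-level weight shards, skip training
        # artifacts ("arguments", "args", "training"), then fetch each shard.
        info = HfApi().model_info(model_id, revision=revision)
        filenames = [
            s.rfilename
            for s in info.siblings
            if s.rfilename.endswith(extension)
            and len(s.rfilename.split("/")) == 1
            and not any(w in s.rfilename for w in ("arguments", "args", "training"))
        ]
        for filename in filenames:
            hf_hub_download(model_id, revision=revision, filename=filename)
        return filenames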
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index af00f5a3..1f5f35c7 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -525,10 +525,9 @@ class MLPSpeculatorModel(torch.nn.Module):
         self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
         # TODO
-        self.vsize = 128256
-        self.inner_dim = 3072
+        self.vsize = config.vocab_size
+        self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
-        self.candidates = 1
 
     def forward(
         self,
@@ -536,27 +535,20 @@ class MLPSpeculatorModel(torch.nn.Module):
         input_ids: torch.Tensor,
     ):
         top_k_tokens_per_head = self.top_k_tokens_per_head
-        num_candidates = self.candidates
-
-        # if state.shape[0] > 1:
-        #     state = state[:1]
 
         # k indicates # of candidates
         # h indicates # of generated tokens
         state = hidden_states
         b = state.size(0)
         ind = input_ids.unsqueeze(0)
-        out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
-        # log_probs = torch.zeros(1, b, device=state.device)  # b k
         all_probs = torch.empty(
-            1, b, self.n_predict, self.vsize, device=state.device
+            b, self.n_predict, self.vsize, device=state.device
         )  # b k h v
         assert (
             len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
         for i in range(self.n_predict):
             # Project and predict
-            # print(ind)
             z = self.emb[i](ind)
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
@@ -565,10 +557,9 @@ class MLPSpeculatorModel(torch.nn.Module):
             _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
-            out[:, :, i : i + 1] = preds
 
             # Update distribution set with new logits
-            all_probs[:, :, i] = probs.exp()
+            all_probs[:, i] = probs.exp()
 
             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
@@ -576,20 +567,8 @@ class MLPSpeculatorModel(torch.nn.Module):
                 -1, b, top_k_tokens_per_head[i], -1
            )  # b k k' d
             state = state.reshape(-1, b, state.size(3))  # b kk' d
             ind = preds.view(-1, b)  # b kk'
-            # log_probs = log_probs.unsqueeze(2).expand(
-            #     -1, b, top_k_tokens_per_head[i]
-            # )  # b k k'
-            # log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
-        # print("done")
 
-        # Take only top n best guesses
-        # best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
-        # speculative_logits = all_probs.gather(
-        #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
-        # ).squeeze(0)
-        speculative_logits = all_probs[0]
-        # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
-        # TODO Why is this shift existing, are speculative logits also including the natural next token ?
+        speculative_logits = all_probs
 
         return speculative_logits
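Note (illustrative, not part of the diff): with the hunks above, forward() now allocates all_probs directly as [batch, n_predict, vocab_size] and returns it as-is, rather than carving the result out of a [1, batch, ...] candidate buffer with a trailing [0]. A toy shape check of that contract; every number below is made up:

    import torch

    b, n_predict, vsize = 2, 3, 11  # toy sizes, not from the patch
    all_probs = torch.empty(b, n_predict, vsize)
    for i in range(n_predict):
        # One distribution per speculative head, as in the loop above.
        probs = torch.log_softmax(torch.randn(b, vsize), dim=-1)
        all_probs[:, i] = probs.exp()
    speculative_logits = all_probs  # returned directly, no [0] indexing anymore
    assert speculative_logits.shape == (b, n_predict, vsize)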
@@ -612,16 +591,13 @@ class MLPSpeculatorHead(nn.Module):
         return logits, speculative_logits
 
     @staticmethod
-    def load(speculator_config, prefix: str, weights):
+    def load(config, prefix: str, weights):
         from pathlib import Path
         from safetensors import safe_open
 
-        speculator_path = speculator_config.speculator
+        speculator_path = config.speculator["path"]
 
-        for fname in [
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ]:
+        for fname in config.speculator["model_paths"]:
             filename = str(Path(speculator_path) / fname)
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
@@ -632,8 +608,8 @@ class MLPSpeculatorHead(nn.Module):
                 for k in f.keys():
                     if k in routing and routing[k] != filename:
                         raise RuntimeError(
                             f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                         )
                     routing[k] = filename
 
-        mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights)
-        lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
+        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        lm_head = TensorParallelHead.load(config, prefix, weights)
 
         return MLPSpeculatorHead(lm_head, mlp_speculator)
 
@@ -726,8 +702,9 @@ class MedusaHeadV2(nn.Module):
 
         speculator = config.speculator
 
-        medusa_config = str(Path(speculator) / "config.json")
-        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
+        path = Path(speculator["path"])
+        medusa_config = str(path / "config.json")
+        filename = path / speculator["model_paths"][0]
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -812,11 +789,14 @@ class SpeculativeHead(nn.Module):
     def load(config, prefix: str, weights):
         speculator = config.speculator
         if speculator:
-            speculator_config = str(Path(speculator) / "config.json")
+
+            speculator_path = config.speculator["path"]
+            speculator_config = str(speculator_path / "config.json")
 
             with open(speculator_config, "r") as f:
                 speculator_config = json.load(f)
 
             lm_head = None
+            config.speculator_config = speculator_config
 
             # currently medusa does not have an architecture specified, so try-except for now
             # this should really be handled in a better way though (maybe the classname can be part of the config)
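Note (illustrative, not part of the diff): taken together, these hunks turn config.speculator from a bare Path into a small dict contract that MLPSpeculatorHead.load, MedusaHeadV2, and SpeculativeHead.load all read the same way. A sketch of the two sides of that contract, with placeholder paths (the shard names are just the ones the old code hardcoded):

    from pathlib import Path

    # Produced by get_model(), in both the medusa and mlp_speculator branches:
    speculator = {
        "path": Path("/data/speculator"),  # placeholder directory
        "model_paths": [
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ],
    }

    # Consumed downstream by the speculative heads:
    config_path = speculator["path"] / "config.json"
    weight_files = [speculator["path"] / p for p in speculator["model_paths"]]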