From 38d60454436b40529bb10cdb87ef0665ad4a3abf Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 6 May 2024 14:03:05 +0000
Subject: [PATCH] Hardcode a few things to make it work.

---
 .../text_generation_server/models/__init__.py | 121 +++++++++-----
 .../custom_modeling/flash_llama_modeling.py   |   2 +-
 .../models/flash_llama.py                     |   4 +-
 server/text_generation_server/utils/layers.py | 145 ++++++++++-------
 4 files changed, 172 insertions(+), 100 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index b52765d7..05adc18e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -136,7 +136,7 @@ def get_model(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
 
-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
@@ -166,11 +166,50 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = Path(medusa_config).parent
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = Path(medusa_model_id)
 
         method = "medusa"
+    elif config_dict["model_type"] == "mlp_speculator":
+        # TODO make this not hardcoded.
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads; please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        is_local = Path(mlp_model_id).exists()
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00001-of-00002.safetensors",
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00002-of-00002.safetensors",
+            )
+            speculator = Path(mlp_speculator_config).parent
+        else:
+            speculator = Path(mlp_model_id)
+        method = "mlp_speculator"
     else:
         method = "n-gram"
 
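The speculate negotiation above mirrors the existing Medusa path: the speculator's config fixes an upper bound on how many tokens can be speculated per step, and a user-supplied value may only lower that bound, never raise it. A minimal standalone sketch of that rule (`resolve_speculate` is a hypothetical name; `set_speculate`/`get_speculate` are the helpers this file already uses):

    from typing import Optional

    def resolve_speculate(requested: Optional[int], n_predict: int) -> int:
        # The speculator ships a fixed number of heads (n_predict); a request
        # may lower the number of speculated tokens but cannot exceed it.
        if requested is None:
            return n_predict
        if requested > n_predict:
            raise RuntimeError(
                f"Speculate is set to `{requested}` but this mlp_speculator "
                f"model only has `{n_predict}` heads; please make them match"
            )
        return requested

    assert resolve_speculate(None, 4) == 4
    assert resolve_speculate(2, 4) == 2
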
@@ -202,7 +241,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -212,7 +251,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -227,7 +266,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -240,7 +279,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -250,7 +289,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -259,7 +298,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -270,7 +309,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -279,7 +318,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -288,7 +327,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -299,7 +338,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -308,7 +347,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -323,7 +362,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -334,7 +373,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -345,7 +384,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -355,7 +394,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -366,7 +405,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -377,7 +416,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -388,7 +427,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -399,7 +438,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -410,7 +449,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -424,7 +463,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -435,7 +474,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -444,7 +483,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -458,7 +497,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -469,7 +508,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -483,7 +522,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -494,7 +533,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -520,7 +559,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -544,7 +583,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -554,7 +593,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -564,7 +603,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -574,7 +613,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -586,7 +625,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -599,7 +638,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -623,7 +662,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -632,7 +671,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -644,7 +683,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -653,7 +692,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a7969494..c5fd0b2c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -419,5 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
         return logits, speculative_logits
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 8ea70713..796fbd47 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
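The two one-line changes to the Llama modeling code and the FlashLlama loader carry the same idea through the stack: the head no longer receives only hidden states, because the MLP speculator advances its own state using the embedding of the most recent token, so the token ids must travel alongside the activations. A toy sketch of the new head contract (class and names hypothetical, not the real TGI types):

    import torch

    class TwoInputHead(torch.nn.Module):
        # Stand-in for SpeculativeHead: takes hidden states *and* token ids.
        def __init__(self, hidden_size: int, vocab_size: int):
            super().__init__()
            self.head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor):
            logits = self.head(hidden_states)
            # A loaded speculator would also produce speculative logits here.
            return logits, None

    head = TwoInputHead(16, 32)
    logits, spec = head(torch.randn(3, 16), torch.zeros(3, dtype=torch.long))
    assert logits.shape == (3, 32) and spec is None
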
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c3080937..b0b271f5 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -442,6 +442,7 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
 
+
 class MLPSpeculatorLayerNorm(nn.Module):
     """
     A L2 normalization implementation
@@ -461,15 +462,14 @@ class MLPSpeculatorLayerNorm(nn.Module):
 
     def __init__(
         self,
-        normalized_shape,
-        elementwise_scale_weight: torch.Tensor,
-        elementwise_shift_bias: torch.Tensor,
+        prefix,
+        config,
+        weights,
         eps=1e-06,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.normalized_shape = normalized_shape
-        self.weight = nn.Parameter(elementwise_scale_weight)
-        self.bias = nn.Parameter(elementwise_shift_bias)
+        self.weight = weights.get_tensor(f"{prefix}.weight")
+        self.bias = weights.get_tensor(f"{prefix}.bias")
         self.eps = eps
 
     def forward(self, x):
@@ -480,48 +480,69 @@ class MLPSpeculatorLayerNorm(nn.Module):
         x = x + self.bias
         return x
 
+
 class MLPSpeculatorModel(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
-        self.n_predict = config.n_predict
-        self.emb_dim = config.emb_dim
-        inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
-        self.inner_dim = inner_dim
-        self.config = config.vocab_size
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
         self.emb = nn.ModuleList(
-            [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
-        )
-        self.proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
-            weights=weights,
-            bias=False,
-            dim=0
+            [
+                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+                for i in range(self.n_predict)
+            ]
         )
+        self.proj = [
+            TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.proj.{i}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_predict)
+        ]
         self.head = nn.ModuleList(
-            [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
+            [
+                TensorParallelRowLinear.load(
+                    config, f"{prefix}.head.{i}", weights, bias=False
+                )
+                for i in range(self.n_predict)
+            ]
         )
         self.ln = nn.ModuleList(
             [
                 MLPSpeculatorLayerNorm(
-                    config.inner_dim,
-                    elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
-                    elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
+                    prefix=f"{prefix}.ln.{i}",
+                    config=config,
+                    weights=weights,
                 )
-                for i in range(config.n_predict)
+                for i in range(self.n_predict)
             ]
         )
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict)
-        self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
+        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
+        # TODO
+        self.vsize = 128256
+        self.inner_dim = 3072
 
-    def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
+    def forward(
+        self,
+        state: torch.Tensor,
+        input_ids: torch.Tensor,
+        top_k_tokens_per_head: Optional[List[int]] = None,
+        num_candidates: int = 1,
+    ):
+        # TODO
+        top_k_tokens_per_head = [1, 1, 1, 1]
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = self.config.top_k_tokens_per_head
+        ind = input_ids
+
         # k indicates # of candidates
         # h indicates # of generated tokens
         b = state.size(0)
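The state_weight/emb_weight constants above implement the variance bookkeeping described in the in-code comment: each head mixes the running state with the token embedding so that squared magnitude is preserved per step, and after all n_predict heads the initial state still contributes half of the squared magnitude. The arithmetic, checked numerically:

    import math

    n_predict = 4
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight**2)

    # Each step is squared-magnitude preserving: w_state^2 + w_emb^2 = 1.
    assert abs(state_weight**2 + emb_weight**2 - 1) < 1e-12
    # After n_predict steps the initial state's squared contribution is
    # state_weight ** (2 * n_predict) = 0.5, i.e. 50% of the magnitude.
    assert abs(state_weight ** (2 * n_predict) - 0.5) < 1e-12
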
@@ -529,7 +550,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         log_probs = torch.zeros(b, 1, device=state.device)  # b k
         all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device)  # b k h v
         assert (
-                len(top_k_tokens_per_head) == self.n_predict
+            len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
         for i in range(self.n_predict):
             # Project and predict
@@ -537,23 +558,33 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2)  # b k k'
+            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' h
-            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
+            out = out.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' h
+            out = torch.cat([out, preds.unsqueeze(3)], dim=-1)  # b k k' h+1
             out = out.view(b, -1, i + 1)  # b kk' h+1
 
             # Update distribution set with new logits
-            all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2)  # b k h+1 v
-            all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1)  # b kk' h+1 v
+            all_probs = torch.cat(
+                [all_probs, _probs.exp().unsqueeze(2)], dim=2
+            )  # b k h+1 v
+            all_probs = all_probs.repeat(
+                1, top_k_tokens_per_head[i], 1, 1
+            )  # b kk' h+1 v
 
             # Update state, log_probs and ind for new predictions
-            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' d
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
             state = state.reshape(b, -1, state.size(3))  # b kk' d
             ind = preds.view(b, -1)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i])  # b k k'
+            log_probs = log_probs.unsqueeze(2).expand(
+                b, -1, top_k_tokens_per_head[i]
+            )  # b k k'
             log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'
 
             # Take only top n best guesses
@@ -570,14 +601,14 @@ class MLPSpeculatorHead(nn.Module):
         self.mlp_speculator = mlp_speculator
 
     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
         if input.shape[0] > 128:
             return logits, None
 
-        speculative_logits = self.mlp_speculator(input)
+        speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
 
     @staticmethod
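The loop above grows a candidate tree: with k candidates in flight and k' = top_k_tokens_per_head[i] continuations per head, every tensor is expanded along a new axis and then flattened back to k*k' candidates. The shape bookkeeping, reproduced on toy tensors (dimension letters as in the comments: b batch, k candidates, k' top-k of the current head, h tokens so far):

    import torch

    b, k, kp, h = 2, 3, 2, 1
    out = torch.zeros(b, k, h, dtype=torch.long)        # b k h
    preds = torch.zeros(b, k, kp, dtype=torch.long)     # b k k' (topk output)

    out = out.unsqueeze(2).expand(-1, -1, kp, -1)       # b k k' h
    out = torch.cat([out, preds.unsqueeze(3)], dim=-1)  # b k k' h+1
    out = out.view(b, -1, h + 1)                        # b kk' h+1
    assert out.shape == (b, k * kp, h + 1)

Note that preds already has shape (b, k, k'), so a single unsqueeze aligns it with the expanded candidate tensor before concatenation.
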
"model-00002-of-00002.safetensors", + ]: + filename = str(Path(speculator_path) / fname) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights) lm_head = TensorParallelHead.load(speculator_config, prefix, weights) @@ -776,9 +817,9 @@ class SpeculativeHead(nn.Module): @staticmethod def load(config, prefix: str, weights): - use_speculator = config.use_speculator - if use_speculator: - speculator_config = str(Path(use_speculator) / "config.json") + speculator = config.speculator + if speculator: + speculator_config = str(Path(speculator) / "config.json") with open(speculator_config, "r") as f: speculator_config = json.load(f) @@ -790,8 +831,7 @@ class SpeculativeHead(nn.Module): architecture = speculator_config["architectures"][0] if architecture == "MLPSpeculatorPreTrainedModel": - speculator_config.use_speculator = config.use_speculator - speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights) + speculator = MLPSpeculatorHead.load(config, prefix, weights) else: speculator = None except KeyError: @@ -805,10 +845,10 @@ class SpeculativeHead(nn.Module): return SpeculativeHead(lm_head, speculator) def forward( - self, input: torch.Tensor + self, input: torch.Tensor, input_ids: torch.Tensor ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.medusa is not None: - return self.medusa(input) + if self.speculator is not None: + return self.speculator(input, input_ids) assert self.head is not None logits = self.head(input)