diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index c5fd0b2c..183625bf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -419,5 +419,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
+        # input_ids = input_ids[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states, input_ids)
         return logits, speculative_logits
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 3e13c26d..dc1f7249 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states, input_ids)
         return logits
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f567bea9..5aa7a568 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0
 
+            logger.info(f"Accepted ids {n_accepted_ids}")
+
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 48304ad8..b83f49a4 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
         config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
@@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
@@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
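Note on the model-side hunks above: both modeling files now call the (possibly speculative) lm_head as lm_head(hidden_states, input_ids), since an MLP speculator conditions its extra prediction heads on the previously accepted token ids, while Medusa-style heads simply ignore that argument (hence the _input_ids parameter further down). The snippet below is only a hedged sketch of that calling contract; the class name SpeculativeHeadSketch and its internals are illustrative and not part of this patch.

# Illustrative sketch only: the head interface the hunks above converge on.
import torch
from typing import Optional, Tuple


class SpeculativeHeadSketch(torch.nn.Module):
    """Hypothetical stand-in for a Medusa or MLP-speculator head (not TGI code)."""

    def __init__(self, lm_head: torch.nn.Module, speculator: Optional[torch.nn.Module] = None):
        super().__init__()
        self.lm_head = lm_head        # plain vocabulary projection
        self.speculator = speculator  # extra heads, or None when speculation is off

    def forward(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Base next-token logits, unchanged from the non-speculative path.
        logits = self.lm_head(hidden_states)
        if self.speculator is None:
            return logits, None
        # Medusa-style heads ignore input_ids; an MLP speculator reads them.
        speculative_logits = self.speculator(hidden_states, input_ids)
        return logits, speculative_logits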
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index b0b271f5..a3515aa1 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,5 +1,6 @@
 import json
 import os
+import math
 from pathlib import Path
 
 import torch
@@ -494,7 +495,7 @@ class MLPSpeculatorModel(torch.nn.Module):
             ]
         )
         self.proj = [
-            TensorParallelColumnLinear.load(
+            FastLinear.load(
                 config,
                 prefix=f"{prefix}.proj.{i}",
                 weights=weights,
@@ -504,9 +505,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         ]
         self.head = nn.ModuleList(
             [
-                TensorParallelRowLinear.load(
-                    config, f"{prefix}.head.{i}", weights, bias=False
-                )
+                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                 for i in range(self.n_predict)
             ]
         )
@@ -528,32 +527,36 @@ class MLPSpeculatorModel(torch.nn.Module):
         # TODO
         self.vsize = 128256
         self.inner_dim = 3072
+        self.top_k_tokens_per_head = [1] * self.n_predict
+        self.candidates = 1
 
     def forward(
         self,
-        state: torch.Tensor,
+        hidden_states: torch.Tensor,
         input_ids: torch.Tensor,
-        top_k_tokens_per_head: Optional[List[int]] = None,
-        num_candidates: int = 1,
     ):
-        # TODO
-        top_k_tokens_per_head = [1, 1, 1, 1]
-        if top_k_tokens_per_head is None:
-            top_k_tokens_per_head = self.config.top_k_tokens_per_head
+        top_k_tokens_per_head = self.top_k_tokens_per_head
+        num_candidates = self.candidates
-        ind = input_ids
+        # if state.shape[0] > 1:
+        #     state = state[:1]
         # k indicates # of candidates
         # h indicates # of generated tokens
+        state = hidden_states
         b = state.size(0)
-        out = torch.empty(b, 1, 0, device=state.device).int()  # b k h
-        log_probs = torch.zeros(b, 1, device=state.device)  # b k
-        all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device)  # b k h v
+        ind = input_ids[-b:].unsqueeze(0)
+        out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
+        log_probs = torch.zeros(1, b, device=state.device)  # b k
+        all_probs = torch.empty(
+            1, b, self.n_predict, self.vsize, device=state.device
+        )  # b k h v
 
         assert (
             len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
 
         for i in range(self.n_predict):
             # Project and predict
+            # print(ind)
             z = self.emb[i](ind)
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
@@ -562,43 +565,32 @@ class MLPSpeculatorModel(torch.nn.Module):
             probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(
-                -1, -1, top_k_tokens_per_head[i], -1
-            )  # b k k' h
-            try:
-                out = torch.cat(
-                    [out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
-                )  # b k k' h+1
-            except Exception:
-                import ipdb
-
-                ipdb.set_trace()
-            out = out.view(b, -1, i + 1)  # b kk' h+1
+            out[:, :, i : i + 1] = preds
 
             # Update distribution set with new logits
-            all_probs = torch.cat(
-                [all_probs, _probs.exp().unsqueeze(2)], dim=-1
-            )  # b k h+1 v
-            all_probs = all_probs.repeat(
-                1, top_k_tokens_per_head[i], 1, 1
-            )  # b kk' h+1 v
+            all_probs[:, :, i] = _probs.exp()
 
             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
                 -1, -1, top_k_tokens_per_head[i], -1
             )  # b k k' d
-            state = state.reshape(b, -1, state.size(3))  # b kk' d
-            ind = preds.view(b, -1)  # b kk'
+            state = state.reshape(-1, b, state.size(3))  # b kk' d
+            ind = preds.view(-1, b)  # b kk'
             log_probs = log_probs.unsqueeze(2).expand(
-                b, -1, top_k_tokens_per_head[i]
+                -1, b, top_k_tokens_per_head[i]
             )  # b k k'
-            log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'
+            log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
+            # print("done")
 
         # Take only top n best guesses
         best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
-        return all_probs.gather(
-            1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
-        )  # b n h v
+        # speculative_logits = all_probs.gather(
+        #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
+        # ).squeeze(0)
+        speculative_logits = all_probs[0]
+        # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
+        # TODO Why is this shift existing, are speculative logits also including the natural next token ?
+        return speculative_logits[:, 1:]
 
 
 class MLPSpeculatorHead(nn.Module):
@@ -692,10 +684,10 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -713,7 +705,7 @@ class MedusaHeadV1(nn.Module):
         return MedusaHeadV1(lm_head, medusa)
 
     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, _input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
@@ -731,10 +723,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -765,7 +757,7 @@ class MedusaHeadV2(nn.Module):
 
         self.lm_head = TensorParallelHead.load(config, prefix, weights)
 
-    def forward(self, x):
+    def forward(self, x, _input_ids):
         # If we have too many tokens, we skip speculative logits
         if x.shape[0] > 128:
             logits = self.lm_head(x)
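Taken together, the layers.py hunks pin the MLP speculator to greedy, single-candidate speculation (top_k_tokens_per_head is all ones and candidates = 1), so out, log_probs and all_probs keep a fixed leading candidate dimension instead of growing head by head, and the per-head projections and heads load as plain FastLinear rather than tensor-parallel layers. Below is a self-contained sketch of that simplified recurrence under stated assumptions: TinyMLPSpeculator, its dimensions, the GELU activation and the weighting constants are placeholders chosen for illustration, and the final [:, 1:] shift from the patch (the open TODO above) is deliberately left out.

# Hedged sketch of the single-candidate MLP speculator recurrence (illustrative only).
import math
import torch
import torch.nn as nn


class TinyMLPSpeculator(nn.Module):
    """Hypothetical miniature speculator, not the TGI MLPSpeculatorModel."""

    def __init__(self, emb_dim=64, inner_dim=32, vocab=101, n_predict=3):
        super().__init__()
        self.n_predict = n_predict
        self.inner_dim = inner_dim
        self.vsize = vocab
        self.emb = nn.ModuleList(nn.Embedding(vocab, inner_dim) for _ in range(n_predict))
        self.proj = nn.ModuleList(
            nn.Linear(emb_dim if i == 0 else inner_dim, inner_dim, bias=False)
            for i in range(n_predict)
        )
        self.head = nn.ModuleList(nn.Linear(inner_dim, vocab, bias=False) for _ in range(n_predict))
        self.ln = nn.ModuleList(nn.LayerNorm(inner_dim) for _ in range(n_predict))
        self.state_weight = 0.5**0.5  # assumed scaling; mirrors the patch only loosely
        self.emb_weight = 0.5**0.5
        self.activation = nn.GELU()

    def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # hidden_states: [b, emb_dim], final hidden state per sequence
        # input_ids:     [b], last accepted token id per sequence
        state = hidden_states
        b = state.size(0)
        ind = input_ids.unsqueeze(0)  # 1 x b: a single candidate per sequence
        all_probs = torch.empty(1, b, self.n_predict, self.vsize)
        for i in range(self.n_predict):
            # Mix the previous token's embedding into the running state, as in the patch.
            z = self.emb[i](ind) * (self.emb_weight * math.sqrt(self.inner_dim / 2))
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))
            log_p = torch.log_softmax(self.head[i](state), dim=-1)  # 1 x b x vocab
            all_probs[:, :, i] = log_p.exp()
            ind = log_p.argmax(dim=-1)  # greedy: feed the best token to the next head
        return all_probs[0]  # [b, n_predict, vocab]: one distribution per extra head


# Example: two sequences, three speculative heads.
if __name__ == "__main__":
    spec = TinyMLPSpeculator()
    dists = spec(torch.randn(2, 64), torch.tensor([5, 7]))
    print(dists.shape)  # torch.Size([2, 3, 101])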