diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py
index e40b2fc6..f08cb673 100644
--- a/server/text_generation_server/layers/mlp.py
+++ b/server/text_generation_server/layers/mlp.py
@@ -1,4 +1,5 @@
 import torch
+import math
 from torch import nn
 from torch.nn import functional as F
 from typing import Optional, Tuple
@@ -6,6 +7,7 @@
 from text_generation_server.layers import TensorParallelEmbedding, FastLinear
 from text_generation_server.layers.tensor_parallel import TensorParallelHead
 from text_generation_server.utils.speculate import get_speculate
+
 class MLPSpeculatorLayerNorm(nn.Module):
     """
     A L2 normalization implementation
@@ -140,7 +142,7 @@ class MLPSpeculatorHead(nn.Module):
         self.mlp_speculator = mlp_speculator
 
     def forward(
-        self, input: torch.Tensor, input_ids: torch.Tensor
+        self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
@@ -172,4 +174,3 @@ class MLPSpeculatorHead(nn.Module):
         mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
         return MLPSpeculatorHead(lm_head, mlp_speculator)
-
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 183625bf..a7969494 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -419,6 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        # input_ids = input_ids[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
+        logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index dc1f7249..3e13c26d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states, input_ids)
+        logits = self.lm_head(hidden_states)
         return logits
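
Below the patch, a minimal sketch of the call pattern it settles on: MLPSpeculatorHead.forward now takes only the hidden states and returns (logits, speculative_logits). StubSpeculatorHead and its sizes are hypothetical stand-ins; only the forward signature mirrors the diff.

import torch
from torch import nn
from typing import Optional, Tuple

class StubSpeculatorHead(nn.Module):
    # Hypothetical stand-in mirroring the new MLPSpeculatorHead.forward signature.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # The real head produces speculative logits when speculation is enabled;
        # this stub always returns None for them.
        return logits, None

head = StubSpeculatorHead(hidden_size=64, vocab_size=128)
hidden_states = torch.randn(4, 64)  # [num_tokens, hidden_size], illustrative shape
logits, speculative_logits = head(hidden_states)  # no input_ids argument anymore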