Rebase after refactor.

This commit is contained in:
Nicolas Patry 2024-05-13 12:44:06 +00:00
parent b884899086
commit 71a535e401
3 changed files with 5 additions and 5 deletions

View File

@ -1,4 +1,5 @@
import torch
import math
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
@ -6,6 +7,7 @@ from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate
class MLPSpeculatorLayerNorm(nn.Module):
"""
An L2 normalization implementation
@ -140,7 +142,7 @@ class MLPSpeculatorHead(nn.Module):
self.mlp_speculator = mlp_speculator
def forward(
self, input: torch.Tensor, input_ids: torch.Tensor
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
@ -172,4 +174,3 @@ class MLPSpeculatorHead(nn.Module):
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MLPSpeculatorHead(lm_head, mlp_speculator)

View File

@ -419,6 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
# input_ids = input_ids[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states, input_ids)
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states, input_ids)
logits = self.lm_head(hidden_states)
return logits