mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Rebase after refactor.
This commit is contained in:
parent
b884899086
commit
71a535e401
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional, Tuple
|
||||
@ -6,6 +7,7 @@ from text_generation_server.layers import TensorParallelEmbedding, FastLinear
|
||||
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
|
||||
|
||||
class MLPSpeculatorLayerNorm(nn.Module):
|
||||
"""
|
||||
A L2 normalization implementation
|
||||
@ -140,7 +142,7 @@ class MLPSpeculatorHead(nn.Module):
|
||||
self.mlp_speculator = mlp_speculator
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, input_ids: torch.Tensor
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
# If we have too many tokens, we skip speculative logits
|
||||
@ -172,4 +174,3 @@ class MLPSpeculatorHead(nn.Module):
|
||||
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
return MLPSpeculatorHead(lm_head, mlp_speculator)
|
||||
|
||||
|
@ -419,6 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
# input_ids = input_ids[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states, input_ids)
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states, input_ids)
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
|
Loading…
Reference in New Issue
Block a user