Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Rebase after refactor.
This commit is contained in:
parent b884899086
commit 71a535e401
@@ -1,4 +1,5 @@
 import torch
+import math
 from torch import nn
 from torch.nn import functional as F
 from typing import Optional, Tuple
@@ -6,6 +7,7 @@ from text_generation_server.layers import TensorParallelEmbedding, FastLinear
 from text_generation_server.layers.tensor_parallel import TensorParallelHead
 from text_generation_server.utils.speculate import get_speculate
+
 
 class MLPSpeculatorLayerNorm(nn.Module):
     """
     A L2 normalization implementation
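
For readers skimming the diff: the class in the context above is documented in-source as "A L2 normalization implementation". A minimal sketch of that style of norm follows; the class name, the learned scale, and the eps default are illustrative assumptions, not the mirrored MLPSpeculatorLayerNorm:

import torch
from torch import nn

class L2LayerNorm(nn.Module):
    # Illustrative stand-in, not the mirrored class.
    def __init__(self, width: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(width))  # learned per-channel scale
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector to unit L2 norm along the last dim, then rescale.
        x = x / x.norm(2, dim=-1, keepdim=True).clamp_min(self.eps)
        return self.weight * x
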
@@ -140,7 +142,7 @@ class MLPSpeculatorHead(nn.Module):
         self.mlp_speculator = mlp_speculator
 
     def forward(
-        self, input: torch.Tensor, input_ids: torch.Tensor
+        self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
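
The only change in this hunk is the signature: MLPSpeculatorHead.forward now takes the hidden states alone, dropping the unused input_ids parameter. A minimal sketch of the new calling convention; MockHead below is an illustrative stand-in, not the mirrored class:

import torch
from torch import nn
from typing import Optional, Tuple

class MockHead(nn.Module):
    # Stand-in with the same one-argument forward shape as the refactored head.
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        return logits, None  # speculative logits elided in this sketch

head = MockHead(16, 32)
logits, speculative_logits = head(torch.randn(2, 16))  # no input_ids argument

The two call-site hunks further down apply the matching change in the Llama and Mistral models.
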
@@ -172,4 +174,3 @@ class MLPSpeculatorHead(nn.Module):
         mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
         return MLPSpeculatorHead(lm_head, mlp_speculator)
-
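
Aside from what appears to be a removed trailing blank line, this load hunk is pure context: the factory still builds the MLPSpeculatorModel and the TensorParallelHead from the shared weights and returns them wired into an MLPSpeculatorHead. The start line shifts from 172 to 174 because of the two lines added near the top of the file.
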
@@ -419,6 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        # input_ids = input_ids[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
+        logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
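
For context on the unchanged lm_head_indices lines above: with a packed batch, only some positions need logits (typically the last token of each sequence), so the hidden states are gathered before the head runs. A tiny sketch with made-up shapes:

import torch

hidden_states = torch.randn(7, 16)              # 7 packed tokens, hidden size 16 (made up)
lm_head_indices = torch.tensor([2, 4, 6])       # assumed last-token positions of 3 sequences
hidden_states = hidden_states[lm_head_indices]  # -> shape (3, 16)
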
@@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states, input_ids)
+        logits = self.lm_head(hidden_states)
         return logits
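
Net effect of the commit: both call sites now invoke the head as self.lm_head(hidden_states), with no input_ids argument. The Llama path matches the slimmed-down MLPSpeculatorHead.forward signature and keeps its (logits, speculative_logits) pair, while the Mistral path returns plain logits from the same one-argument call.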