mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Remove dead file.
This commit is contained in:
parent
c7793235d0
commit
1445b9517d
@ -1,53 +0,0 @@
|
|||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(torch.nn.Module):
|
|
||||||
def __init__(self, config, prefix, weights):
|
|
||||||
super().__init__()
|
|
||||||
self.linear = FastLinear.load(
|
|
||||||
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
|
||||||
)
|
|
||||||
self.act = torch.nn.SiLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x + self.act(self.linear(x))
|
|
||||||
|
|
||||||
|
|
||||||
class MedusaModel(torch.nn.Module):
|
|
||||||
def __init__(self, config, weights, lm_head):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
|
||||||
for i in range(config["medusa_num_heads"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.lm_head = lm_head
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
logits = self.lm_head(x)
|
|
||||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
|
||||||
return logits, speculative_logits
|
|
||||||
|
|
||||||
|
|
||||||
class MedusaHead(torch.nn.Module):
|
|
||||||
def __init__(self, config, prefix, weights):
|
|
||||||
super().__init__()
|
|
||||||
self.blocks = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
|
||||||
for i in range(config["medusa_num_layers"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
n = len(self.blocks)
|
|
||||||
self.out = FastLinear.load(
|
|
||||||
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
x = self.out(x)
|
|
||||||
return x
|
|
Loading…
Reference in New Issue
Block a user