text-generation-inference/server/text_generation_server/utils/medusa.py

52 lines
1.6 KiB
Python
Raw Normal View History

2023-09-11 22:12:19 +00:00
import torch
from dataclasses import dataclass
from typing import Optional

from text_generation_server.utils.layers import TensorParallelHead, FastLinear

@dataclass
class Output:
    """Result of a MedusaModel forward pass.

    Bundles the regular LM-head logits together with the stacked logits
    produced by the Medusa speculative-decoding heads.
    """

    # Logits from the base LM head; None when not set.
    logits: Optional[torch.FloatTensor] = None
    # Stacked per-head logits from the Medusa heads (stacked on dim 1 in
    # MedusaModel.forward); None when not set.
    speculative_logits: Optional[torch.FloatTensor] = None
class ResBlock(torch.nn.Module):
    """Residual block: ``x + SiLU(Linear(x))``."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        # Linear projection loaded from the checkpoint under `{prefix}.linear`.
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        # Activated projection with a residual connection around it.
        projected = self.linear(x)
        return x + self.act(projected)
class MedusaModel(torch.nn.Module):
    """Wraps an LM head with a bank of Medusa speculative-decoding heads."""

    def __init__(self, config, weights, lm_head):
        super().__init__()
        num_heads = config["medusa_num_heads"]
        # One independent MedusaHead per speculative position; the checkpoint
        # stores each head's weights under its integer index as the prefix.
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, prefix=f"{idx}", weights=weights)
                for idx in range(num_heads)
            ]
        )
        self.lm_head = lm_head

    def forward(self, x):
        # Regular logits from the base LM head.
        logits = self.lm_head(x)
        # Each Medusa head contributes one extra set of logits; stack them
        # along a new dim 1.
        per_head = [head(x) for head in self.heads]
        speculative_logits = torch.stack(per_head, dim=1)
        return Output(logits=logits, speculative_logits=speculative_logits)
class MedusaHead(torch.nn.Module):
    """A single Medusa head: residual blocks followed by an output projection."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        num_layers = config["medusa_num_layers"]
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{layer}", weights=weights)
                for layer in range(num_layers)
            ]
        )
        # The output projection lives in the checkpoint immediately after the
        # residual blocks, i.e. at index `num_layers` under the same prefix.
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{num_layers}", weights=weights, bias=False
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)