diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py deleted file mode 100644 index 9f66ba10..00000000 --- a/server/text_generation_server/utils/medusa.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from dataclasses import dataclass -from text_generation_server.utils.layers import TensorParallelHead, FastLinear - - -class ResBlock(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) - self.act = torch.nn.SiLU() - - def forward(self, x): - return x + self.act(self.linear(x)) - - -class MedusaModel(torch.nn.Module): - def __init__(self, config, weights, lm_head): - super().__init__() - self.heads = torch.nn.ModuleList( - [ - MedusaHead(config, prefix=f"{i}", weights=weights) - for i in range(config["medusa_num_heads"]) - ] - ) - self.lm_head = lm_head - - def forward(self, x): - logits = self.lm_head(x) - speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) - return logits, speculative_logits - - -class MedusaHead(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.blocks = torch.nn.ModuleList( - [ - ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) - for i in range(config["medusa_num_layers"]) - ] - ) - n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) - - def forward(self, x): - for block in self.blocks: - x = block(x) - x = self.out(x) - return x