From 1445b9517df85e24d3f19209248e9f67df38f978 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Feb 2024 15:15:02 +0100 Subject: [PATCH] Remove dead file. --- server/text_generation_server/utils/medusa.py | 53 ------------------- 1 file changed, 53 deletions(-) delete mode 100644 server/text_generation_server/utils/medusa.py diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py deleted file mode 100644 index 9f66ba10..00000000 --- a/server/text_generation_server/utils/medusa.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from dataclasses import dataclass -from text_generation_server.utils.layers import TensorParallelHead, FastLinear - - -class ResBlock(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) - self.act = torch.nn.SiLU() - - def forward(self, x): - return x + self.act(self.linear(x)) - - -class MedusaModel(torch.nn.Module): - def __init__(self, config, weights, lm_head): - super().__init__() - self.heads = torch.nn.ModuleList( - [ - MedusaHead(config, prefix=f"{i}", weights=weights) - for i in range(config["medusa_num_heads"]) - ] - ) - self.lm_head = lm_head - - def forward(self, x): - logits = self.lm_head(x) - speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) - return logits, speculative_logits - - -class MedusaHead(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.blocks = torch.nn.ModuleList( - [ - ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) - for i in range(config["medusa_num_layers"]) - ] - ) - n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) - - def forward(self, x): - for block in self.blocks: - x = block(x) - x = self.out(x) - return x