Dummy fix for medusa.

This commit is contained in:
Nicolas Patry 2024-04-12 10:12:09 +00:00
parent b24bdb9f8c
commit 10dd0150c0
2 changed files with 35 additions and 25 deletions

View File

@@ -874,21 +874,15 @@ class FlashCausalLM(Model):
lm_head_indices = batch.prefill_head_indices
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
graphs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if graphs:
cuda_graph = graphs[0]
else:
cuda_graph = None
if (
cu_seqlen_prefill is not None
or cuda_graph is None
or batch.speculative_ids is not None
):
return self.model.forward(
input_ids=input_ids,
@@ -980,6 +974,9 @@ class FlashCausalLM(Model):
speculative_logits,
)
if os.getenv("RANK") == "0":
logger.info(f"Accepted ids {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)

View File

@@ -412,48 +412,61 @@ class SuperLayer(nn.Module):
# NOTE(review): this span is a flattened diff of ResBlock. Lines from before and
# after the change appear together, without +/- markers or indentation, so it is
# not runnable as written. The comments describe what each visible line does.
#
# ResBlock: residual block that adds SiLU(linear(x)) to its input.
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
# Two openers for self.linear follow and share the argument line below.
# Presumably the FastLinear line is the removed one and the
# TensorParallelColumnLinear line is its replacement - confirm against the diff.
self.linear = FastLinear.load(
self.linear = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
# Process-group handle plus its size and this process's rank; forward() uses
# them to pick a slice of the input and to gather the per-rank results.
self.process_group = weights.process_group
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
self.act = torch.nn.SiLU()
def forward(self, x):
# One-line variant: plain residual over the full input.
# Presumably the pre-change body - confirm against the diff.
return x + self.act(self.linear(x))
# Multi-line variant: each rank combines its own slice of x with the activated
# linear output, then the per-rank pieces are gathered and concatenated.
size = x.shape[-1]
# Ceiling division, so world_size blocks cover the whole last dimension.
block_size = (size + self.world_size - 1) // self.world_size
start = self.rank * block_size
stop = (self.rank + 1) * block_size
# x is indexed with two axes here, so a 2-D input is assumed - TODO confirm.
# NOTE(review): the addition needs self.linear(x) to match the slice width
# (stop - start); verify when size is not divisible by world_size.
output = x[:, start: stop] + self.act(self.linear(x))
# One receive buffer per rank in the group, each shaped like the local output.
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
# Join the gathered pieces along the last dimension.
world_output = torch.cat(world_output, dim=-1)
return world_output
# NOTE(review): this span is a flattened diff of MedusaModel. Lines from before
# and after the change appear together, without +/- markers or indentation, so
# it is not runnable as written.
#
# MedusaModel: holds a ModuleList of MedusaHead modules and stacks their outputs.
class MedusaModel(torch.nn.Module):
# Two signatures are visible; the second adds medusa_config and lm_head.
# Presumably the first is the removed one - confirm against the diff.
def __init__(self, config, weights):
def __init__(self, config, medusa_config, weights, lm_head):
super().__init__()
self.heads = torch.nn.ModuleList(
[
# Two construction variants are visible. One reads the head count from
# config["medusa_num_heads"]. The other reads it from medusa_config and also
# passes medusa_config through to each MedusaHead.
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
for i in range(medusa_config["medusa_num_heads"])
]
)
# Stored so forward() can apply it to the stacked head outputs.
self.lm_head = lm_head
def forward(self, x):
# Variant 1: the stacked head outputs (new axis at dim=1) are returned directly
# as the speculative logits.
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
# Variant 2: the stacked head outputs are treated as hidden states and passed
# through self.lm_head to produce the speculative logits.
speculative_hidden_states = torch.stack([head(x) for head in self.heads], dim=1)
speculative_logits = self.lm_head(speculative_hidden_states)
return speculative_logits
# NOTE(review): this span is a flattened diff of MedusaHead. Lines from before
# and after the change appear together, without +/- markers or indentation, so
# it is not runnable as written.
#
# MedusaHead: a sequence of ResBlock modules applied one after another.
class MedusaHead(torch.nn.Module):
# Two signatures are visible; the second adds medusa_config.
# Presumably the first is the removed one - confirm against the diff.
def __init__(self, config, prefix, weights):
def __init__(self, config, medusa_config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
# Two loop headers are visible: the layer count is read either from
# config["medusa_num_layers"] or from medusa_config["medusa_num_layers"].
for i in range(config["medusa_num_layers"])
for i in range(medusa_config["medusa_num_layers"])
]
)
# The output projection is loaded under the index that follows the last block.
# NOTE(review): it is not clear from this view whether the commit keeps or
# removes the self.out lines - confirm against the diff.
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
# Run the input through every block in order.
for block in self.blocks:
x = block(x)
# Final bias-free projection (see the note on self.out above).
x = self.out(x)
return x
@@ -476,7 +489,7 @@ class SpeculativeHead(nn.Module):
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
@@ -486,7 +499,7 @@ class SpeculativeHead(nn.Module):
)
weights.routing[k] = filename
medusa = MedusaModel(config, weights)
medusa = MedusaModel(config, medusa_config, weights, lm_head)
else:
medusa = None
return SpeculativeHead(lm_head, medusa)