Dummy fix for medusa.

Nicolas Patry 2024-04-12 10:12:09 +00:00
parent b24bdb9f8c
commit 10dd0150c0
2 changed files with 35 additions and 25 deletions


@@ -874,21 +874,15 @@ class FlashCausalLM(Model):
         lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        graphs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if graphs:
+            cuda_graph = graphs[0]
+        else:
+            cuda_graph = None

         if (
             cu_seqlen_prefill is not None
             or cuda_graph is None
-            or batch.speculative_ids is not None
         ):
             return self.model.forward(
                 input_ids=input_ids,
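
The old path padded the batch size up to a fixed tier (4, 8, or the next multiple of 8) and required a graph captured at exactly that size; the new path instead picks the smallest captured graph that can hold the current batch, and no longer bails out of the graph path when speculative ids are present. A minimal sketch of the selection logic, assuming `self.cuda_graphs` is a dict keyed by the batch size each graph was captured at (the helper name and example keys below are illustrative):

from typing import Optional

# Illustrative sketch: choose the smallest captured CUDA graph that fits `bs`.
# `cuda_graphs` is assumed to be keyed by capture batch size, e.g. {1, 2, 4, 8}.
def select_graph_key(cuda_graphs: dict, bs: int) -> Optional[int]:
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    return candidates[0] if candidates else None

# A batch of 3 now reuses the bs=4 graph instead of needing an exact match:
assert select_graph_key({1: ..., 2: ..., 4: ..., 8: ...}, 3) == 4
# No captured graph is large enough: fall back to eager self.model.forward.
assert select_graph_key({1: ..., 2: ...}, 5) is None

Note that, as in the diff, the lookup yields the graph's key (its capture batch size) rather than the graph object itself.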
@@ -980,6 +974,9 @@ class FlashCausalLM(Model):
             speculative_logits,
         )

+        if os.getenv("RANK") == "0":
+            logger.info(f"Accepted ids {accepted_ids}")
+
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )


@@ -412,48 +412,61 @@ class SuperLayer(nn.Module):

 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
-        self.linear = FastLinear.load(
+        self.linear = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.linear", weights=weights, bias=True
         )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
         self.act = torch.nn.SiLU()

     def forward(self, x):
-        return x + self.act(self.linear(x))
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+        output = x[:, start: stop] + self.act(self.linear(x))
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output
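
The ResBlock linear is now column-parallel, so each rank computes only a slice of the hidden dimension; the residual input must be sliced to the matching columns before the add, and the shards are reassembled with an all_gather. A self-contained sketch of that forward, under the assumption that the hidden size divides evenly across ranks (all_gather requires equal shard shapes, so the ceil-division block size is only clean in that case); function and argument names are illustrative:

import torch
import torch.distributed as dist

# Illustrative: the residual add when the linear output is column-sharded.
# `local_linear_out` is this rank's shard of linear(x), covering columns
# [rank * block_size, (rank + 1) * block_size) of the full hidden dim.
def sharded_resblock_forward(x, local_linear_out, act, process_group):
    world_size = process_group.size()
    rank = process_group.rank()
    size = x.shape[-1]
    block_size = (size + world_size - 1) // world_size  # ceil division
    start, stop = rank * block_size, (rank + 1) * block_size
    # Slice the residual to the same columns this rank produced.
    output = x[:, start:stop] + act(local_linear_out)
    # Reassemble the full hidden dimension on every rank.
    world_output = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(world_output, output, group=process_group)
    return torch.cat(world_output, dim=-1)

The all_gather leaves every rank holding the full activation, so the next ResBlock (and eventually the shared lm_head) can shard it again independently.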

 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights, lm_head):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
+        self.lm_head = lm_head

     def forward(self, x):
-        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_hidden_states = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_logits = self.lm_head(speculative_hidden_states)
         return speculative_logits
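
With this change the Medusa heads stop projecting to the vocabulary themselves: each head returns hidden states, and one shared lm_head projects all of them at once, so the large vocab matrix is not duplicated per head. A shape-only sketch with hypothetical sizes, using Identity as a stand-in for the real MedusaHead residual stacks:

import torch

num_tokens, hidden, vocab, n_heads = 4, 16, 32, 3
x = torch.randn(num_tokens, hidden)
heads = [torch.nn.Identity() for _ in range(n_heads)]  # stand-ins for MedusaHead
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

# Stacking per-head hidden states: [num_tokens, n_heads, hidden].
speculative_hidden_states = torch.stack([head(x) for head in heads], dim=1)
# The shared head projects the trailing dim: [num_tokens, n_heads, vocab].
speculative_logits = lm_head(speculative_hidden_states)
assert speculative_logits.shape == (num_tokens, n_heads, vocab)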

 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
-        self.out = FastLinear.load(
-            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
-        )

     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-        x = self.out(x)
         return x
@@ -476,7 +489,7 @@ class SpeculativeHead(nn.Module):
             filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
             with open(medusa_config, "r") as f:
-                config = json.load(f)
+                medusa_config = json.load(f)
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
@@ -486,7 +499,7 @@ class SpeculativeHead(nn.Module):
                     )
                     weights.routing[k] = filename

-            medusa = MedusaModel(config, weights)
+            medusa = MedusaModel(config, medusa_config, weights, lm_head)
         else:
             medusa = None
         return SpeculativeHead(lm_head, medusa)
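
The rename in these last two hunks keeps the base model `config` separate from the Medusa JSON: previously the parsed JSON was bound to `config`, clobbering the model config that is also passed down to MedusaModel and ResBlock. A rough sketch of the resulting load path (the safetensors routing is elided, the JSON filename is hypothetical, and as in the diff `medusa_config` starts life as a file path and is rebound to the parsed dict):

import json
from pathlib import Path

# Simplified sketch of SpeculativeHead.load's medusa branch.
def load_medusa_head(config, use_medusa, weights, lm_head):
    medusa_config = str(Path(use_medusa) / "config.json")  # hypothetical filename
    with open(medusa_config, "r") as f:
        medusa_config = json.load(f)  # rebound: path -> parsed dict
    # The shared lm_head is threaded through so MedusaModel can project the
    # speculative hidden states with the same vocabulary head.
    medusa = MedusaModel(config, medusa_config, weights, lm_head)
    return SpeculativeHead(lm_head, medusa)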