remove movedim

Author: OlivierDehaene
Date:   2024-04-12 16:23:54 +02:00
Parent: 68717f8716
Commit: 0dd617b822

@@ -557,16 +557,12 @@ class MedusaHeadV2(nn.Module):
         x_block = x[:, start:stop]

         # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
-        medusa_res = torch.movedim(
-            self.act(self.linear(x)).reshape(
-                *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
-            ),
-            -2,
-            1,
-        ).contiguous()
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )

         # Apply all residual medusa heads
-        output = x[:, start:stop].unsqueeze(1) + medusa_res
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res

         # Gather medusa heads
         world_output = [
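
Why the movedim is safe to drop: the reshape already leaves the n_medusa_heads axis at -2, and moving it to dim 1 (plus the .contiguous() that can force a copy) only paid off because the ops below indexed from the front; with the next hunk switching those ops to index from the end, the reshape stays a cheap view. A minimal standalone sketch of the equivalence, not taken from this file, assuming made-up sizes B, S, H, N and a random heads_out standing in for self.act(self.linear(x)):

import torch

B, S, H, N = 2, 3, 8, 4  # hypothetical batch, sequence, hidden, n_medusa_heads
x = torch.randn(B, S, H)
heads_out = torch.randn(B, S, N * H)  # stand-in for self.act(self.linear(x))

# Old layout: move the head axis to dim 1, paying for a .contiguous() copy
old = torch.movedim(heads_out.reshape(B, S, N, H), -2, 1).contiguous()
old_output = x.unsqueeze(1) + old  # [B, 1, S, H] + [B, N, S, H] -> [B, N, S, H]

# New layout: leave the head axis at -2; the reshape is a view, no copy
new = heads_out.reshape(B, S, N, H)
new_output = x.unsqueeze(-2) + new  # [B, S, 1, H] + [B, S, N, H] -> [B, S, N, H]

# Same values either way, just a different axis order
assert torch.allclose(old_output, new_output.movedim(-2, 1))

Either layout carries the same numbers; the new one skips the copy and keeps every later dim argument (unsqueeze, cat, split, squeeze) rank-agnostic at -2.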
@@ -576,17 +572,17 @@ class MedusaHeadV2(nn.Module):
         world_output = torch.cat(world_output, dim=-1)

         # Stack x and medusa residual x
-        stacked_x = torch.cat([x.unsqueeze(1), world_output], dim=1)
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

         # Compute lm head on x + medusa residual x
         logits = self.lm_head(stacked_x)

         # Finally, split logits from speculative logits
         logits, speculative_logits = torch.split(
-            logits, [1, self.n_medusa_heads], dim=1
+            logits, [1, self.n_medusa_heads], dim=-2
         )

         # Squeeze added dimension
-        logits = logits.squeeze(1)
+        logits = logits.squeeze(-2)

         return logits, speculative_logits
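
The same end-relative indexing carries through the tail of forward: after stacking along -2, the split and squeeze address that axis from the end regardless of how many leading dims x has. A hedged sketch of just that tail, with made-up sizes (V standing in for the vocab dimension, logits for self.lm_head(stacked_x)):

import torch

B, S, N, V = 2, 3, 4, 10  # hypothetical batch, sequence, n_medusa_heads, vocab
logits = torch.randn(B, S, 1 + N, V)  # stand-in for self.lm_head(stacked_x)

# Split the stacked axis (-2) into base logits (width 1) and N speculative heads
base, speculative = torch.split(logits, [1, N], dim=-2)

# Drop the singleton axis that the earlier unsqueeze(-2) introduced
base = base.squeeze(-2)

assert base.shape == (B, S, V)
assert speculative.shape == (B, S, N, V)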