diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 4cc2245d..2c12984b 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -557,16 +557,12 @@ class MedusaHeadV2(nn.Module): x_block = x[:, start:stop] # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1 - medusa_res = torch.movedim( - self.act(self.linear(x)).reshape( - *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1] - ), - -2, - 1, - ).contiguous() + medusa_res = self.act(self.linear(x)).reshape( + *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1] + ) # Apply all residual medusa heads - output = x[:, start:stop].unsqueeze(1) + medusa_res + output = x[:, start:stop].unsqueeze(-2) + medusa_res # Gather medusa heads world_output = [ @@ -576,17 +572,17 @@ class MedusaHeadV2(nn.Module): world_output = torch.cat(world_output, dim=-1) # Stack x and medusa residual x - stacked_x = torch.cat([x.unsqueeze(1), world_output], dim=1) + stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2) # Compute lm head on x + medusa residual x logits = self.lm_head(stacked_x) # Finally, split logits from speculative logits logits, speculative_logits = torch.split( - logits, [1, self.n_medusa_heads], dim=1 + logits, [1, self.n_medusa_heads], dim=-2 ) # Squeeze added dimension - logits = logits.squeeze(1) + logits = logits.squeeze(-2) return logits, speculative_logits