From 0dd617b8229ecd801b50e74feadafaab816844be Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 12 Apr 2024 16:23:54 +0200
Subject: [PATCH] remove movedim

---
 server/text_generation_server/utils/layers.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 4cc2245d..2c12984b 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -557,16 +557,12 @@ class MedusaHeadV2(nn.Module):
         x_block = x[:, start:stop]
 
         # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
-        medusa_res = torch.movedim(
-            self.act(self.linear(x)).reshape(
-                *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
-            ),
-            -2,
-            1,
-        ).contiguous()
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )
 
         # Apply all residual medusa heads
-        output = x[:, start:stop].unsqueeze(1) + medusa_res
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res
 
         # Gather medusa heads
         world_output = [
@@ -576,17 +572,17 @@ class MedusaHeadV2(nn.Module):
         world_output = torch.cat(world_output, dim=-1)
 
         # Stack x and medusa residual x
-        stacked_x = torch.cat([x.unsqueeze(1), world_output], dim=1)
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
 
         # Compute lm head on x + medusa residual x
         logits = self.lm_head(stacked_x)
 
         # Finally, split logits from speculative logits
         logits, speculative_logits = torch.split(
-            logits, [1, self.n_medusa_heads], dim=1
+            logits, [1, self.n_medusa_heads], dim=-2
         )
         # Squeeze added dimension
-        logits = logits.squeeze(1)
+        logits = logits.squeeze(-2)
 
         return logits, speculative_logits
 