mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
remove movedim
This commit is contained in:
parent
68717f8716
commit
0dd617b822
@ -557,16 +557,12 @@ class MedusaHeadV2(nn.Module):
|
||||
x_block = x[:, start:stop]
|
||||
|
||||
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
|
||||
medusa_res = torch.movedim(
|
||||
self.act(self.linear(x)).reshape(
|
||||
medusa_res = self.act(self.linear(x)).reshape(
|
||||
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
|
||||
),
|
||||
-2,
|
||||
1,
|
||||
).contiguous()
|
||||
)
|
||||
|
||||
# Apply all residual medusa heads
|
||||
output = x[:, start:stop].unsqueeze(1) + medusa_res
|
||||
output = x[:, start:stop].unsqueeze(-2) + medusa_res
|
||||
|
||||
# Gather medusa heads
|
||||
world_output = [
|
||||
@ -576,17 +572,17 @@ class MedusaHeadV2(nn.Module):
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
|
||||
# Stack x and medusa residual x
|
||||
stacked_x = torch.cat([x.unsqueeze(1), world_output], dim=1)
|
||||
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
|
||||
|
||||
# Compute lm head on x + medusa residual x
|
||||
logits = self.lm_head(stacked_x)
|
||||
|
||||
# Finally, split logits from speculative logits
|
||||
logits, speculative_logits = torch.split(
|
||||
logits, [1, self.n_medusa_heads], dim=1
|
||||
logits, [1, self.n_medusa_heads], dim=-2
|
||||
)
|
||||
# Squeeze added dimension
|
||||
logits = logits.squeeze(1)
|
||||
logits = logits.squeeze(-2)
|
||||
|
||||
return logits, speculative_logits
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user