Medusa requires reshaping.

This commit is contained in:
Nicolas Patry 2024-08-13 16:25:29 +02:00
parent 99b6b5c795
commit 4fff77ebcb
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@ -1412,7 +1412,7 @@ class FlashCausalLM(Model):
).view(-1) ).view(-1)
prefix_lens_tensor = ( prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
).view(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
block_tables = ( block_tables = (