Fixing linting on main.

This commit is contained in:
Nicolas Patry 2024-11-04 15:10:26 +01:00
parent aadc9cb485
commit b81231c790
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -1729,9 +1729,11 @@ class FlashCausalLM(Model):
# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
# then update the slots with the additional indices to ensure we're grabbing the ones that have been
# allocated
slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
slot_indices = (
batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
slots = batch.slots[slot_indices]
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)