fix kernel

This commit is contained in:
OlivierDehaene 2024-10-24 19:17:31 +02:00
parent d1e95ceaff
commit 347f3f51da
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -229,11 +229,7 @@ def triton_copy_next_input_ids_inplace(
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
tl.store(
all_input_ids_ptr
+ stride_all_input_ids * bid
+ cache_length
+ input_length
+ block_arange,
all_input_ids_ptr + stride_all_input_ids * bid + cache_length + block_arange,
next_input_ids,
mask=mask,
)