fix kernel

This commit is contained in:
OlivierDehaene 2024-10-25 10:37:10 +02:00
parent 347f3f51da
commit a7465ba67d
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -218,7 +218,7 @@ def triton_copy_next_input_ids_inplace(
mask = (next_input_ids_start + block_arange) < next_input_ids_end
# Mask values for request still prefilling
decode_mask = (cache_length + input_length + block_arange) > prompt_length
decode_mask = (cache_length + input_length + block_arange) >= prompt_length
mask = mask & decode_mask
@ -229,7 +229,11 @@ def triton_copy_next_input_ids_inplace(
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
tl.store(
all_input_ids_ptr + stride_all_input_ids * bid + cache_length + block_arange,
all_input_ids_ptr
+ stride_all_input_ids * bid
+ cache_length
+ input_length
+ block_arange,
next_input_ids,
mask=mask,
)