diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py index bf6dd252..2238631b 100644 --- a/server/text_generation_server/models/metadata_kernels.py +++ b/server/text_generation_server/models/metadata_kernels.py @@ -229,11 +229,7 @@ def triton_copy_next_input_ids_inplace( # Store in all_input_ids, since it is a 2D tensor, apply stride * bid tl.store( - all_input_ids_ptr - + stride_all_input_ids * bid - + cache_length - + input_length - + block_arange, + all_input_ids_ptr + stride_all_input_ids * bid + cache_length + block_arange, next_input_ids, mask=mask, )