diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py index 2238631b..2357564e 100644 --- a/server/text_generation_server/models/metadata_kernels.py +++ b/server/text_generation_server/models/metadata_kernels.py @@ -218,7 +218,7 @@ def triton_copy_next_input_ids_inplace( mask = (next_input_ids_start + block_arange) < next_input_ids_end # Mask values for request still prefilling - decode_mask = (cache_length + input_length + block_arange) > prompt_length + decode_mask = (cache_length + input_length + block_arange) >= prompt_length mask = mask & decode_mask @@ -229,7 +229,11 @@ def triton_copy_next_input_ids_inplace( # Store in all_input_ids, since it is a 2D tensor, apply stride * bid tl.store( - all_input_ids_ptr + stride_all_input_ids * bid + cache_length + block_arange, + all_input_ids_ptr + + stride_all_input_ids * bid + + cache_length + + input_length + + block_arange, next_input_ids, mask=mask, )