Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix kernel
This commit is contained in:
parent 347f3f51da
commit a7465ba67d
@@ -218,7 +218,7 @@ def triton_copy_next_input_ids_inplace(
     mask = (next_input_ids_start + block_arange) < next_input_ids_end

     # Mask values for request still prefilling
-    decode_mask = (cache_length + input_length + block_arange) > prompt_length
+    decode_mask = (cache_length + input_length + block_arange) >= prompt_length

     mask = mask & decode_mask

@@ -229,7 +229,11 @@ def triton_copy_next_input_ids_inplace(

     # Store in all_input_ids, since it is a 2D tensor, apply stride * bid
     tl.store(
-        all_input_ids_ptr + stride_all_input_ids * bid + cache_length + block_arange,
+        all_input_ids_ptr
+        + stride_all_input_ids * bid
+        + cache_length
+        + input_length
+        + block_arange,
         next_input_ids,
         mask=mask,
     )
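For context, the following is a minimal pure-Python sketch (not the actual Triton kernel) of the indexing logic the two changes touch. The helper name, the list-based buffers, and the num_tokens bound are illustrative assumptions; the variable names mirror those in the diff. One way to read the fix: with >= instead of >, the first decode token of a request that has just finished prefilling is no longer masked out, and with the extra + input_length in the store offset, new tokens are appended after everything already written for the request rather than on top of it.

# Minimal pure-Python sketch of the corrected indexing (assumed helper, not the kernel).
def copy_next_input_ids_sketch(
    all_input_ids,      # one request's row of the 2D all_input_ids tensor
    next_input_ids,     # freshly sampled token ids for this block
    cache_length,       # tokens already in the KV cache for this request
    input_length,       # tokens processed in the current step
    prompt_length,      # total prompt tokens for this request
    num_tokens,         # valid entries in next_input_ids (stands in for the
                        # next_input_ids_start/end bound in the real kernel)
    BLOCK_SIZE=4,
):
    for block_arange in range(BLOCK_SIZE):
        # Only copy slots that actually hold a sampled token.
        mask = block_arange < num_tokens
        # Fixed comparison: the request is decoding once it has seen at least
        # prompt_length tokens (>=, not >), so the first post-prefill token is kept.
        decode_mask = (cache_length + input_length + block_arange) >= prompt_length
        if mask and decode_mask:
            # Fixed offset: new tokens land after cache_length + input_length
            # already-stored tokens, not after cache_length alone.
            all_input_ids[cache_length + input_length + block_arange] = next_input_ids[block_arange]
    return all_input_ids

# Example: 5-token prompt, 4 tokens cached plus 1 processed this step,
# one newly sampled token (99) written into slot 5.
row = [11, 12, 13, 14, 15, 0, 0, 0]
print(copy_next_input_ids_sketch(row, [99], cache_length=4, input_length=1,
                                 prompt_length=5, num_tokens=1))
# -> [11, 12, 13, 14, 15, 99, 0, 0]

With the old > comparison, 5 >= 5 would instead evaluate as 5 > 5 and the token would be dropped; with the old offset of cache_length alone, it would overwrite slot 4. This is an illustrative reading of the diff, not a statement of the kernel's full behavior.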