diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 7b60593d..d6acafa1 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -62,5 +62,9 @@ mod ffi {
         fn pull_tokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
         ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
+
+        /// Cancel the request identified by `request_id`.
+        // NOTE(review): exact semantics live in the C++ implementation — confirm there.
+        fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
     }
 }
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 8e9ff49d..0a4499b5 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -121,7 +121,10 @@ fn executor_status_looper(
         }
 
         if backend.num_tokens_ready() > 0 {
-            match backend.pin_mut().pull_tokens() {
+            // `Pin<&mut T>` is not `Copy` and is moved by `self: Pin<&mut T>` methods,
+            // so pin once here and reborrow with `as_mut()` for every FFI call below.
+            let mut backend = backend.pin_mut();
+            match backend.as_mut().pull_tokens() {
                 Ok(responses) => {
                     // Iterate through all the decoded token
                     for step in responses.deref() {
@@ -140,6 +143,9 @@ fn executor_status_looper(
 
                             if posted.is_err() || step.is_final {
                                 debug!("Removing {}", step.request_id);
+                                // Posting failed or this was the final step: also tell the
+                                // backend to cancel the request before dropping its context.
+                                backend.as_mut().cancel(step.request_id);
                                 let _ = in_flights.remove(&step.request_id);
                             }
                         } else {