mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
* test(ctest) enable address sanitizer * feat(trtllm): expose finish reason to Rust * feat(trtllm): fix logits retrieval * misc(ci): enabe building tensorrt-llm * misc(ci): update Rust action toolchain * misc(ci): let's try to build the Dockerfile for trtllm # Conflicts: # Dockerfile_trtllm * misc(ci): provide mecanism to cache inside container * misc(ci): export aws creds as output of step * misc(ci): let's try this way * misc(ci): again * misc(ci): again * misc(ci): add debug profile * misc(ci): add debug profile * misc(ci): lets actually use sccache ... * misc(ci): do not build with ssl enabled * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(backend): test with TGI S3 conf * misc(backend): test with TGI S3 conf * misc(backend): once more? * misc(backend): let's try with GHA * misc(backend): missing env directive * misc(backend): make sure to correctly set IS_GHA_BUILD=true in wf * misc(backend): ok let's debug smtg * misc(backend): WWWWWWWWWWWWWAAAAAAAA * misc(backend): kthxbye retry s3 * misc(backend): use session token * misc(backend): add more info * misc(backend): lets try 1h30 * misc(backend): lets try 1h30 * misc(backend): increase to 2h * misc(backend): lets try... * misc(backend): lets try... 
* misc(backend): let's build for ci-runtime * misc(backend): let's add some more tooling * misc(backend): add some tags * misc(backend): disable Werror for now * misc(backend): added automatic gha detection * misc(backend): remove leak sanitizer which is included in asan * misc(backend): forward env * misc(backend): forward env * misc(backend): let's try * misc(backend): let's try * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): fix sscache -> sccache * misc(backend): fix sscache -> sccache * misc(backend): fix sscache -> sccache * misc(backend): let's actually cache things now * misc(backend): let's actually cache things now * misc(backend): attempt to run the testS? * misc(backend): attempt to run the tests? * misc(backend): attempt to run the tests? * change runner size * fix: Correctly tag docker images (#2878) * fix: Correctly tag docker images * fix: Correctly tag docker images * misc(llamacpp): maybe? * misc(llamacpp): maybe? * misc(llamacpp): maybe? * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): go * misc(ci): go * misc(ci): go * misc(ci): use bin folder * misc(ci): make the wf callable for reuse * misc(ci): make the wf callable for reuse (bis) * misc(ci): make the wf callable for reuse (bis) * misc(ci): give the wf a name * Create test-trtllm.yml * Update test-trtllm.yml * Create build-trtllm2 * Rename build-trtllm2 to 1-build-trtllm2 * Rename test-trtllm.yml to 1-test-trtllm2.yml * misc(ci): fw secrets * Update 1-test-trtllm2.yml * Rename 1-build-trtllm2 to 1-build-trtllm2.yml * Update 1-test-trtllm2.yml * misc(ci): use ci-build.yaml as main dispatcher * Delete .github/workflows/1-test-trtllm2.yml * Delete .github/workflows/1-build-trtllm2.yml * misc(ci): rights? * misc(ci): rights? * misc(ci): once more? * misc(ci): once more? * misc(ci): baby more time? * misc(ci): baby more time? 
* misc(ci): try the permission above again? * misc(ci): try the permission above again? * misc(ci): try the permission scoped again? * misc(ci): install tensorrt_llm_executor_static * misc(ci): attempt to rebuild with sccache? * misc(ci):run the tests on GPU instance * misc(ci): let's actually setup sccache in the build.rs * misc(ci): reintroduce variables * misc(ci): enforce sccache * misc(ci): correct right job name dependency * misc(ci): detect dev profile for debug * misc(ci): detect gha build * misc(ci): detect gha build * misc(ci): ok debug * misc(ci): wtf * misc(ci): wtf2 * misc(ci): wtf3 * misc(ci): use commit HEAD instead of merge commit for image id * misc(ci): wtfinfini * misc(ci): wtfinfini * misc(ci): KAMEHAMEHA * Merge TRTLLM in standard CI * misc(ci): remove input machine * misc(ci): missing id-token for AWS auth * misc(ci): missing id-token for AWS auth * misc(ci): missing id-token for AWS auth * misc(ci): again... * misc(ci): again... * misc(ci): again... * misc(ci): again... * misc(ci): missing benchmark * misc(ci): missing backends * misc(ci): missing launcher * misc(ci): give everything aws needs * misc(ci): give everything aws needs * misc(ci): fix warnings * misc(ci): attempt to fix sccache not building trtllm * misc(ci): attempt to fix sccache not building trtllm again --------- Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com> Co-authored-by: Hugo Larcher <hugo.larcher@huggingface.co> Co-authored-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com>
104 lines
3.1 KiB
Rust
104 lines
3.1 KiB
Rust
pub use looper::TensorRtLlmBackendV2;
|
|
|
|
pub mod errors;
|
|
mod looper;
|
|
mod utils;
|
|
|
|
/// FFI bridge between this Rust crate and the C++ TensorRT-LLM backend.
///
/// The shared types below (`FinishReason`, `GenerationStep`) and the extern
/// signatures must stay in exact sync with the C++ definitions included via
/// `backends/trtllm/csrc/ffi.hpp` — do not reorder fields or change
/// discriminant values without updating the C++ side.
#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
    /// Why a request stopped generating, mirroring the C++ `finish_reason_t`.
    /// Discriminant values must match the C++ enum exactly.
    #[cxx_name = "finish_reason_t"]
    #[derive(Debug, Clone, Copy)]
    pub enum FinishReason {
        /// The request is not finished.
        #[cxx_name = "kNOT_FINISHED"]
        NotFinished = 0u8,

        /// The request finished because the end id was generated.
        #[cxx_name = "kEND_ID"]
        EndTokenId = 1u8,

        /// The request finished because a stop word was generated.
        #[cxx_name = "kSTOP_WORDS"]
        StopWords = 2u8,

        /// The request finished because the maximum number of tokens was reached.
        #[cxx_name = "kLENGTH"]
        MaxLength = 3u8,
    }

    /// Struct used as shared type between rust and C++ to represent the result
    /// of a single decoding iteration
    #[cxx_name = "generation_step_t"]
    #[derive(Debug, Clone)]
    pub struct GenerationStep {
        /// Identifier of the request this step belongs to (the id returned by `submit`).
        request_id: u64,
        /// Token produced during this decoding iteration.
        token_id: u32,
        /// Log-probability associated with `token_id`.
        log_prob: f32,
        /// Whether this step is the last one for the request.
        is_final: bool,
        /// Why generation stopped — presumably only meaningful when `is_final`
        /// is true; confirm against the C++ producer.
        finish_reason: FinishReason,
        /// True when the step carries an error rather than a usable token.
        has_error: bool,
        /// Error description coming from the C++ side; NOTE(review): looks like
        /// it is empty when `has_error` is false — verify in ffi.hpp.
        error_msg: String,
    }

    unsafe extern "C++" {
        include!("backends/trtllm/csrc/ffi.hpp");

        /// Represent an instance of the underlying TensorRT-LLM backend
        #[cxx_name = "tensorrt_llm_backend_t"]
        type TensorRtLlmBackendImpl;

        /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
        ///
        /// # Arguments
        ///
        /// * `engine_folder`: Path to the folder containing all the TRTLLM engines
        /// * `executor_worker`: Path to the TRTLLM executor worker
        ///
        /// returns: `Result<UniquePtr<TensorRtLlmBackendImpl>>` — the backend
        /// handle on success, or the C++ exception surfaced as an error.
        fn create_backend_from_engine_folder(
            engine_folder: &str,
            executor_worker: &str,
        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;

        /// Number of generation steps currently ready to be retrieved with
        /// `pull_tokens` — TODO confirm exact semantics against ffi.hpp.
        fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;

        /// Submit a new generation request to the backend.
        ///
        /// Returns the request id on success (usable with `cancel` and matched
        /// against `GenerationStep::request_id`), or an error from the C++ side.
        fn submit(
            self: Pin<&mut TensorRtLlmBackendImpl>,
            tokens: &[u32],
            max_new_tokens: u32,
            top_k: u32,
            top_p: f32,
            temperature: f32,
            repetition_penalty: f32,
            frequency_penalty: f32,
            seed: u64,
        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;

        /// Retrieve the generation steps currently available across all
        /// in-flight requests.
        fn pull_tokens(
            self: Pin<&mut TensorRtLlmBackendImpl>,
        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;

        /// Cancel the in-flight request identified by `request_id`.
        fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
    }
}
|
|
|
|
use ffi::FinishReason;
|
|
use text_generation_router::FinishReason as InferFinishReason;
|
|
|
|
impl From<FinishReason> for InferFinishReason {
|
|
fn from(reason: FinishReason) -> Self {
|
|
match reason {
|
|
FinishReason::StopWords => InferFinishReason::StopSequence,
|
|
FinishReason::MaxLength => InferFinishReason::Length,
|
|
FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
|
|
_ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
|
|
}
|
|
}
|
|
}
|