mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
* test(ctest) enable address sanitizer * feat(trtllm): expose finish reason to Rust * feat(trtllm): fix logits retrieval * misc(ci): enabe building tensorrt-llm * misc(ci): update Rust action toolchain * misc(ci): let's try to build the Dockerfile for trtllm # Conflicts: # Dockerfile_trtllm * misc(ci): provide mecanism to cache inside container * misc(ci): export aws creds as output of step * misc(ci): let's try this way * misc(ci): again * misc(ci): again * misc(ci): add debug profile * misc(ci): add debug profile * misc(ci): lets actually use sccache ... * misc(ci): do not build with ssl enabled * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(ci): WAT * misc(backend): test with TGI S3 conf * misc(backend): test with TGI S3 conf * misc(backend): once more? * misc(backend): let's try with GHA * misc(backend): missing env directive * misc(backend): make sure to correctly set IS_GHA_BUILD=true in wf * misc(backend): ok let's debug smtg * misc(backend): WWWWWWWWWWWWWAAAAAAAA * misc(backend): kthxbye retry s3 * misc(backend): use session token * misc(backend): add more info * misc(backend): lets try 1h30 * misc(backend): lets try 1h30 * misc(backend): increase to 2h * misc(backend): lets try... * misc(backend): lets try... 
* misc(backend): let's build for ci-runtime * misc(backend): let's add some more tooling * misc(backend): add some tags * misc(backend): disable Werror for now * misc(backend): added automatic gha detection * misc(backend): remove leak sanitizer which is included in asan * misc(backend): forward env * misc(backend): forward env * misc(backend): let's try * misc(backend): let's try * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): again * misc(backend): fix sscache -> sccache * misc(backend): fix sscache -> sccache * misc(backend): fix sscache -> sccache * misc(backend): let's actually cache things now * misc(backend): let's actually cache things now * misc(backend): attempt to run the testS? * misc(backend): attempt to run the tests? * misc(backend): attempt to run the tests? * change runner size * fix: Correctly tag docker images (#2878) * fix: Correctly tag docker images * fix: Correctly tag docker images * misc(llamacpp): maybe? * misc(llamacpp): maybe? * misc(llamacpp): maybe? * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): gogogo * misc(ci): go * misc(ci): go * misc(ci): go * misc(ci): use bin folder * misc(ci): make the wf callable for reuse * misc(ci): make the wf callable for reuse (bis) * misc(ci): make the wf callable for reuse (bis) * misc(ci): give the wf a name * Create test-trtllm.yml * Update test-trtllm.yml * Create build-trtllm2 * Rename build-trtllm2 to 1-build-trtllm2 * Rename test-trtllm.yml to 1-test-trtllm2.yml * misc(ci): fw secrets * Update 1-test-trtllm2.yml * Rename 1-build-trtllm2 to 1-build-trtllm2.yml * Update 1-test-trtllm2.yml * misc(ci): use ci-build.yaml as main dispatcher * Delete .github/workflows/1-test-trtllm2.yml * Delete .github/workflows/1-build-trtllm2.yml * misc(ci): rights? * misc(ci): rights? * misc(ci): once more? * misc(ci): once more? * misc(ci): baby more time? * misc(ci): baby more time? 
* misc(ci): try the permission above again? * misc(ci): try the permission above again? * misc(ci): try the permission scoped again? * misc(ci): install tensorrt_llm_executor_static * misc(ci): attempt to rebuild with sccache? * misc(ci):run the tests on GPU instance * misc(ci): let's actually setup sccache in the build.rs * misc(ci): reintroduce variables * misc(ci): enforce sccache * misc(ci): correct right job name dependency * misc(ci): detect dev profile for debug * misc(ci): detect gha build * misc(ci): detect gha build * misc(ci): ok debug * misc(ci): wtf * misc(ci): wtf2 * misc(ci): wtf3 * misc(ci): use commit HEAD instead of merge commit for image id * misc(ci): wtfinfini * misc(ci): wtfinfini * misc(ci): KAMEHAMEHA * Merge TRTLLM in standard CI * misc(ci): remove input machine * misc(ci): missing id-token for AWS auth * misc(ci): missing id-token for AWS auth * misc(ci): missing id-token for AWS auth * misc(ci): again... * misc(ci): again... * misc(ci): again... * misc(ci): again... * misc(ci): missing benchmark * misc(ci): missing backends * misc(ci): missing launcher * misc(ci): give everything aws needs * misc(ci): give everything aws needs * misc(ci): fix warnings * misc(ci): attempt to fix sccache not building trtllm * misc(ci): attempt to fix sccache not building trtllm again --------- Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com> Co-authored-by: Hugo Larcher <hugo.larcher@huggingface.co> Co-authored-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com>
81 lines
3.5 KiB
C++
81 lines
3.5 KiB
C++
#include <ranges>
|
|
|
|
#include <nlohmann/json.hpp>
|
|
|
|
#include "backend.hpp"
|
|
#include "hardware.hpp"
|
|
|
|
namespace huggingface::tgi::backends::trtllm {
|
|
/**
 * Infer the TRT-LLM parallelism layout from the engine's pretrained config.
 *
 * A world size of 1 (TP = PP = 1) selects leader mode (no MPI orchestration);
 * anything larger switches to orchestrator mode with a spawned executor worker.
 *
 * @return ParallelConfig describing communication type/mode for the executor
 */
tle::ParallelConfig backend_workspace_t::parallel_config() const {
    const auto world_size = config_["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();

    auto comm_mode = tle::CommunicationMode::kLEADER;
    std::optional<tle::OrchestratorConfig> orchestrator_config;

    if (world_size > 1) {
        SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
        comm_mode = tle::CommunicationMode::kORCHESTRATOR;
        // Spawn the dedicated executor worker process for multi-rank engines
        orchestrator_config.emplace(true, executor_worker_path_, nullptr, true);
    } else {
        SPDLOG_INFO("Detected single engine deployment, using leader mode");
    }

    return tle::ParallelConfig(tle::CommunicationType::kMPI, comm_mode, std::nullopt, std::nullopt,
                               orchestrator_config);
}
|
|
|
|
|
|
/**
 * Build the ExecutorConfig used to instantiate the TRT-LLM executor.
 *
 * Combines the parallelism layout from parallel_config() with runtime knobs:
 * KV-cache config (flag enabled), chunked context (Ampere or newer only),
 * and MAX_UTILIZATION capacity scheduling.
 *
 * @return Fully populated tle::ExecutorConfig
 */
tle::ExecutorConfig backend_workspace_t::executor_config() const {
    // Probe GPU compute capabilities so hardware-dependent options can be toggled
    const auto compute_capabilities = hardware::cuda::compute_capabilities_t();

    // Start from a base config with beam width fixed to 1
    tle::ExecutorConfig cfg(/* maxBeamWidth = */ 1);

    // Apply the parallelism layout inferred from the engine's pretrained config
    cfg.setParallelConfig(parallel_config());

    // Runtime knobs (KvCacheConfig(true): first ctor flag enabled — see TRT-LLM docs for its meaning)
    cfg.setKvCacheConfig(tle::KvCacheConfig(true));
    cfg.setEnableChunkedContext(compute_capabilities.is_at_least_ampere());
    cfg.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
    return cfg;
}
|
|
|
|
// Construct the backend: materialize the workspace from the engines folder and the
// executor worker binary path, then bring up the TRT-LLM executor via the factory
// initializer (member-init order follows declaration order: workspace before executor_).
backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
: workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
|
|
|
|
size_t backend_t::num_tokens_ready() const noexcept {
|
|
return executor_.getNumResponsesReady();
|
|
}
|
|
|
|
/**
 * Enqueue a new generation request on the TRT-LLM executor.
 *
 * @param token_ids Prompt token ids; copied into the request, so the span need
 *                  not outlive this call
 * @param g_params  Generation parameters (max_new_tokens is forwarded)
 * @param s_params  Sampling parameters, converted to a tle::SamplingConfig
 * @return Executor-assigned request id on success, backend_error_t otherwise
 */
std::expected<request_id_t, backend_error_t>
backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
                  const sampling_params_t s_params) noexcept {
    SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
    return executor_.enqueueRequest(tle::Request{
            {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
            static_cast<tle::SizeType32>(g_params.max_new_tokens),
            true, // NOTE(review): presumably the streaming flag — confirm against tle::Request ctor
            static_cast<tle::SamplingConfig>(s_params), // named cast instead of C-style cast
            tle::OutputConfig{ /* returnLogProbs= */ true},
            std::nullopt, // positional optional fields left defaulted —
            std::nullopt, // see tle::Request ctor for their meaning
            std::nullopt,
            std::nullopt,
            workspace.generation_config().stop_words
    });
}
|
|
|
|
std::vector<tle::Response> backend_t::pull_tokens() noexcept {
|
|
SPDLOG_TRACE(FMT_STRING("Pulling out tokens ({:d} available)"), num_tokens_ready());
|
|
return executor_.awaitResponses();
|
|
}
|
|
|
|
/**
 * Ask the executor to drop an in-flight request.
 * @param request_id Id previously returned by submit()
 */
void backend_t::cancel(request_id_t request_id) noexcept {
    SPDLOG_TRACE(FMT_STRING("Cancelling request: {:d}"), request_id);
    // Forward straight to the TRT-LLM executor; no local bookkeeping to update
    executor_.cancelRequest(request_id);
}
|
|
}
|