(ffi) fix usage of the wrong vector constructor, which turned an intended capacity reservation into a fill

Morgan Funtowicz 2024-08-09 22:45:18 +02:00
parent dddc9a44bd
commit 8e648ce425
3 changed files with 72 additions and 52 deletions
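For context, a minimal standalone sketch (with a hypothetical Step type, not the backend's GenerationStep) of the distinction the title refers to: std::vector's size_t constructor is a fill constructor that creates n value-initialized elements, whereas reserve(n) only pre-allocates storage and leaves the vector empty.

#include <cassert>
#include <memory>
#include <vector>

struct Step { int token_id; };  // stand-in for GenerationStep

int main() {
    // Fill constructor: the vector already *contains* 4 zeroed elements.
    auto filled = std::make_unique<std::vector<Step>>(4);
    assert(filled->size() == 4);

    // Capacity reservation: storage for 4 elements, but the vector stays empty.
    auto reserved = std::make_unique<std::vector<Step>>();
    reserved->reserve(4);
    assert(reserved->empty() && reserved->capacity() >= 4);
}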


@@ -64,8 +64,6 @@ namespace huggingface::tgi::backends {
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
};
GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
/***
*
* @param engineFolder


@@ -36,34 +36,38 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
const auto responses = TensorRtLlmBackend::PullNewTokens();
auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
return steps;
}
huggingface::tgi::backends::GenerationStep
huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
const auto reqId = response.getRequestId();
if (!response.hasError()) {
const auto result = response.getResult();
return std::move(GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
});
} else {
return std::move(GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(response.getErrorMsg())
});
}
auto steps = std::make_unique<std::vector<GenerationStep>>();
steps->reserve(responses.size());
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses.size());
// Transform tle::Response to GenerationStep
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [&](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) {
const auto result = r.getResult();
return GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
};
} else {
return GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
return steps;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
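The fix above switches PullTokens from a fill-constructed vector to an empty vector plus reserve before appending through std::back_inserter; with the old pattern, the transformed steps were appended after responses.size() default-constructed entries, doubling the output. A self-contained sketch of both patterns, with plain ints standing in for tle::Response and GenerationStep:

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

int main() {
    const std::vector<int> responses{10, 20, 30};

    // Old pattern: fill-constructed vector, back_inserter appends on top of the zeros.
    std::vector<int> buggy(responses.size());
    std::ranges::transform(responses, std::back_inserter(buggy),
                           [](int t) { return t + 1; });
    assert(buggy.size() == 2 * responses.size());  // 3 zeros + 3 transformed values

    // New pattern: start empty, reserve, then append exactly responses.size() elements.
    std::vector<int> fixed;
    fixed.reserve(responses.size());
    std::ranges::transform(responses, std::back_inserter(fixed),
                           [](int t) { return t + 1; });
    assert(fixed.size() == responses.size());
}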


@@ -8,18 +8,20 @@ use cxx::UniquePtr;
use hashbrown::HashMap;
use log::warn;
use tokenizers::{Encoding, Tokenizer};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::mpsc::error::SendError;
use tokio::task::{JoinHandle, spawn_blocking};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, Level, span};
use tracing::{debug, error, info, span, Level};
use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::infer::InferError::GenerationError;
use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -95,6 +97,8 @@ fn executor_status_poller(
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
debug!("Received {} tokens from the executor", responses.len());
// worst-case scenario is one token for each response: with_capacity(responses.len())
// grouper groups decoded tokens per request so several tokens can be decoded together
let mut grouper: HashMap<u64, DecodedTokenContext> =
@@ -102,33 +106,49 @@ fn executor_status_poller(
// Iterate through all the decoded tokens
for step in responses.deref() {
let request_id = step.request_id;
match in_flights.get(&request_id) {
match in_flights.get(&step.request_id) {
Some(ctx) => {
info!("New token for {} -> {}", request_id, step.token_id);
debug!(
"{} -> (token={}, final={})",
step.request_id, step.token_id, step.is_final
);
// If no error, let's forward to post-processor
if !step.has_error {
let req_group = grouper.entry(request_id).or_insert(
let req_group = grouper.entry(step.request_id).or_insert(
DecodedTokenContext {
tokens: vec![],
ctx: ctx.streamer.clone(), // Arc::clone() = cheap
},
);
req_group.tokens.push(step.clone()); // Should be ultra cheap
if step.is_final {
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!(
"Error for request: {} -> {}",
request_id, &step.error_msg
step.request_id, &step.error_msg
);
// TODO: Send something back to the postprocessor for the client?
}
// Remove from tracked requests
if step.is_final {
let _ = in_flights.remove(&step.request_id);
}
}
None => {
error!("Got step for untracked request {}", request_id);
if step.has_error {
error!(
"Untracked request {} -> {}",
step.request_id, &step.error_msg
);
continue;
} else {
error!(
"Got step for untracked request {}",
step.request_id
);
}
}
}
}
@@ -275,18 +295,16 @@ impl TensorRtLlmBackendV2 {
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
return Err(InferError::ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
return Err(InferError::ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
0 => Err(InferError::ValidationError(EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),