mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 07:52:06 +00:00

(ffi) fix usage of wrong vector constructor making a capacity fill call

parent: dddc9a44bd
commit: 8e648ce425
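For context, the bug named in the title: std::vector's count constructor value-initializes that many elements, so a vector built with make_unique<vector<GenerationStep>>(responses.size()) already holds responses.size() default-constructed entries, and everything appended through std::back_inserter lands behind them. A minimal standalone sketch (plain ints instead of the real GenerationStep/Response types) of the symptom and of the reserve()-based fix:

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

int main() {
    const std::vector<int> responses{1, 2, 3};

    // Buggy pattern: the count constructor value-initializes 3 zero elements up front...
    std::vector<int> filled(responses.size());
    std::ranges::copy(responses, std::back_inserter(filled));
    assert(filled.size() == 6);   // ...so the copied values land behind them: {0, 0, 0, 1, 2, 3}

    // Fixed pattern: an empty vector plus reserve() only pre-allocates capacity
    std::vector<int> reserved;
    reserved.reserve(responses.size());
    std::ranges::copy(responses, std::back_inserter(reserved));
    assert(reserved.size() == 3); // {1, 2, 3}
    return 0;
}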
@@ -64,8 +64,6 @@ namespace huggingface::tgi::backends {
         std::unique_ptr<std::vector<GenerationStep>> PullTokens();
     };
 
-    GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
-
     /***
      *
      * @param engineFolder
@@ -36,34 +36,38 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
 huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
     const auto responses = TensorRtLlmBackend::PullNewTokens();
-    auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
-    std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
-    return steps;
-}
 
-huggingface::tgi::backends::GenerationStep
-huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
-    const auto reqId = response.getRequestId();
-    if (!response.hasError()) {
-        const auto result = response.getResult();
-        return std::move(GenerationStep{
+    auto steps = std::make_unique<std::vector<GenerationStep>>();
+    steps->reserve(responses.size());
+
+    SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
+
+    // Transform tle::Response to GenerationStep
+    std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [&](const Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) {
+            const auto result = r.getResult();
+            return GenerationStep{
                 reqId,
                 static_cast<uint32_t>(result.outputTokenIds[0][0]),
                 result.logProbs.value()[0][0],
                 result.isFinal,
                 false,
                 std::string()
-        });
+            };
         } else {
-        return std::move(GenerationStep{
+            return GenerationStep{
                 reqId,
                 0,
                 0.0,
                 true,
                 true,
-            std::move(response.getErrorMsg())
-        });
+                std::move(r.getErrorMsg())
+            };
         }
+    });
+
+    return steps;
 }
 
 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
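For reference, a self-contained sketch (dummy Response/Step structs, not the tle:: API) of the shape the rewritten PullTokens() takes: an empty vector plus reserve() for the worst-case size, std::ranges::transform appending through a back_inserter, and the braced temporaries returned without std::move, which would otherwise only suppress copy elision:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-ins for tle::Response and GenerationStep, for illustration only.
struct Response {
    uint64_t id;
    bool ok;
    std::string error;
};

struct Step {
    uint64_t request_id;
    bool has_error;
    std::string error_msg;
};

std::vector<Step> convert(const std::vector<Response> &responses) {
    std::vector<Step> steps;
    steps.reserve(responses.size());  // capacity only, size stays 0
    std::ranges::transform(responses, std::back_inserter(steps), [](const Response &r) {
        if (r.ok) {
            return Step{r.id, false, std::string()};  // plain return: copy elision applies
        } else {
            return Step{r.id, true, r.error};
        }
    });
    return steps;
}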
@@ -8,18 +8,20 @@ use cxx::UniquePtr;
 use hashbrown::HashMap;
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
-use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::task::{JoinHandle, spawn_blocking};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{error, info, Level, span};
+use tracing::{debug, error, info, span, Level};
 
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::infer::InferError::GenerationError;
-use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
-use text_generation_router::validation::ValidationError::UnsupportedModality;
+use text_generation_router::validation::ValidationError::{
+    EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
+};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -95,6 +97,8 @@ fn executor_status_poller(
         if backend.num_responses_ready() > 0 {
             match backend.pin_mut().pull_tokens() {
                 Ok(responses) => {
+                    debug!("Received {} tokens from the executor", responses.len());
+
                     // worse case scenario is one token for each response: with_capacity(responses.len())
                     // grouper will group decoded tokens per request to decode multiple tokens
                     let mut grouper: HashMap<u64, DecodedTokenContext> =
@@ -102,33 +106,49 @@ fn executor_status_poller(
 
                     // Iterate through all the decoded token
                     for step in responses.deref() {
-                        let request_id = step.request_id;
-
-                        match in_flights.get(&request_id) {
+                        match in_flights.get(&step.request_id) {
                             Some(ctx) => {
-                                info!("New token for {} -> {}", request_id, step.token_id);
+                                debug!(
+                                    "{} -> (token={}, final={})",
+                                    step.request_id, step.token_id, step.is_final
+                                );
 
+                                // If no error, let's forward to post-processor
                                 if !step.has_error {
-                                    let req_group = grouper.entry(request_id).or_insert(
+                                    let req_group = grouper.entry(step.request_id).or_insert(
                                         DecodedTokenContext {
                                             tokens: vec![],
                                             ctx: ctx.streamer.clone(), // Arc::clone() = cheap
                                         },
                                     );
                                     req_group.tokens.push(step.clone()); // Should be ultra cheap
-
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
-                                    }
                                 } else {
                                     warn!(
                                         "Error for request: {} -> {}",
-                                        request_id, &step.error_msg
+                                        step.request_id, &step.error_msg
                                     );
+
+                                    // TODO: Send something back to the postprocessor for the client?
+                                }
+
+                                // Remove from tracked requests
+                                if step.is_final {
+                                    let _ = in_flights.remove(&step.request_id);
                                 }
                             }
                             None => {
-                                error!("Got step for untracked request {}", request_id);
+                                if step.has_error {
+                                    error!(
+                                        "Untracked request {} -> {}",
+                                        step.request_id, &step.error_msg
+                                    );
+                                    continue;
+                                } else {
+                                    error!(
+                                        "Got step for untracked request {}",
+                                        step.request_id
+                                    );
+                                }
                             }
                         }
                     }
@@ -275,18 +295,16 @@ impl TensorRtLlmBackendV2 {
 
     fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
         if request.top_n_tokens > 1 {
-            return Err(InferError::ValidationError(
-                ValidationError::TopNTokensDisabled,
-            ));
+            return Err(InferError::ValidationError(TopNTokensDisabled));
         }
 
         // TODO: Is it really needed? How can it be validated before?
         if request.parameters.grammar.is_some() {
-            return Err(InferError::ValidationError(ValidationError::Grammar));
+            return Err(InferError::ValidationError(Grammar));
         }
 
         match request.inputs.len() {
-            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            0 => Err(InferError::ValidationError(EmptyInput)),
             2.. => Err(InferError::GenerationError(
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),