Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-05-21 17:52:09 +00:00

commit 984ae9798f
parent fa63db0d07

    (post) impl postprocessing
@@ -3,22 +3,23 @@ use std::ops::Deref;
 use std::path::Path;
 
 use async_trait::async_trait;
-use cxx::{UniquePtr};
+use cxx::UniquePtr;
-use hashbrown::{HashMap};
+use hashbrown::HashMap;
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::{JoinHandle, spawn_blocking};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error};
 
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -71,6 +72,8 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
 /// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
 struct DecodedTokenContext {
     token: DecodedToken,
+    start: Option<Instant>,
+    queued: Instant,
     channel: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
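Note: the two added fields carry per-request timing from the executor loop into the
stream's End response: queued marks when the request entered the backend, while start
stays None until the first token is produced. A minimal sketch of that idea, using
std::time::Instant in place of tokio's re-export and an illustrative RequestTiming
type that is not part of the codebase:

    use std::time::{Duration, Instant};

    struct RequestTiming {
        queued: Instant,        // when the request entered the queue
        start: Option<Instant>, // set once the first token is produced
    }

    impl RequestTiming {
        fn new() -> Self {
            Self { queued: Instant::now(), start: None }
        }

        // Record the first-token instant; later calls leave it untouched.
        fn on_token(&mut self) {
            self.start.get_or_insert_with(Instant::now);
        }

        // Time from enqueue to first token, if generation has started.
        fn time_to_first_token(&self) -> Option<Duration> {
            self.start.map(|s| s.duration_since(self.queued))
        }
    }

    fn main() {
        let mut t = RequestTiming::new();
        t.on_token();
        println!("ttft: {:?}", t.time_to_first_token());
    }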
@@ -131,12 +134,14 @@ fn executor_status_looper(
                     // Iterate through all the decoded token
                     for step in responses.deref() {
                         if let Some(ctx) = in_flights.get(&step.request_id) {
 
                             // Remove from tracked requests
-                            let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                                token: dt,
-                                channel: ctx.streamer.clone(),
-                            });
+                            let parcel =
+                                DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                                    token: dt,
+                                    start: ctx.start,
+                                    queued: ctx.queued,
+                                    channel: ctx.streamer.clone(),
+                                });
 
                             // Submit the work to the post_processor
                             let posted = post_processor_sender.send((step.request_id, parcel));
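Note: parcel is a Result, so decode failures travel down the same per-request channel
as successful tokens and the post-processor decides what to surface. A sketch of this
handoff pattern, simplified to std::sync::mpsc and a stand-in Parcel alias; the real
code uses tokio's unbounded_channel with blocking_recv inside a spawn_blocking task:

    use std::sync::mpsc::channel;
    use std::thread;

    type Parcel = Result<u32 /* token id */, String /* error */>;

    fn main() {
        let (tx, rx) = channel::<(u64, Parcel)>();

        // Stand-in for post_processor_looper: drain parcels on a dedicated thread.
        let post_processor = thread::spawn(move || {
            while let Ok((request_id, parcel)) = rx.recv() {
                match parcel {
                    Ok(token_id) => println!("request {request_id}: token {token_id}"),
                    Err(e) => eprintln!("request {request_id}: {e}"),
                }
            }
        });

        // Stand-in for executor_status_looper: push results as they arrive.
        tx.send((1, Ok(42))).unwrap();
        tx.send((1, Err("decoding failed".into()))).unwrap();
        drop(tx); // close the channel so the receiving loop terminates

        post_processor.join().unwrap();
    }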
@@ -148,7 +153,7 @@ fn executor_status_looper(
                         } else {
                             warn!("Untracked request {}", step.request_id,);
                         }
-                    };
+                    }
                 }
                 Err(ref err) => {
                     error!("Failed to get responses from the executor: {}.", err.what());
@@ -176,12 +181,60 @@ fn post_processor_looper(
 
         if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
             let state = states.entry(request_id).or_insert(vec![]);
 
+            match decoded {
+                Ok(ctx) => {
+                    state.push(ctx.token.id);
+                    let out = match tokenizer.decode(&[ctx.token.id], false) {
+                        Ok(text) => {
+                            let is_special =
+                                tokenizer.get_added_vocabulary().is_special_token(&text);
+                            let token = Token {
+                                id: ctx.token.id,
+                                text,
+                                logprob: ctx.token.log_prob,
+                                special: is_special,
+                            };
+
+                            let out = if !ctx.token.is_final {
+                                InferStreamResponse::Intermediate {
+                                    token,
+                                    top_tokens: vec![],
+                                }
+                            } else {
+                                let text = tokenizer.decode(&state, true);
+                                let generated_text = GeneratedText {
+                                    text: text.unwrap(),
+                                    generated_tokens: state.len() as u32,
+                                    finish_reason: FinishReason::EndOfSequenceToken,
+                                    seed: None,
+                                };
+
+                                InferStreamResponse::End {
+                                    token,
+                                    top_tokens: vec![],
+                                    generated_text,
+                                    start: ctx.start.unwrap(),
+                                    queued: ctx.queued,
+                                }
+                            };
+
+                            Ok(out)
+                        }
+                        Err(err) => Err(GenerationError(err.to_string())),
+                    };
+
+                    if let Err(_) = ctx.channel.send(out) {
+                        warn!("Failed to send decoded token back to the user")
+                    }
+                }
+                Err(err) => {}
+            }
         }
     }
 }
 
-unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {}
+unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
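Note: the post-processor decodes each new id on its own for the streamed Intermediate
token, then re-decodes the whole accumulated state (skipping special tokens) for the
final text carried by End. A sketch of those two decode paths using the same
tokenizers calls as the diff, assuming only a local tokenizer.json whose path is
illustrative:

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Assumed local file; any serialized tokenizer works here.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        let mut state: Vec<u32> = Vec::new();

        for id in [9906u32, 1917] { // illustrative token ids
            state.push(id);
            // Per-token decode, keeping special tokens so they can be flagged.
            let piece = tokenizer.decode(&[id], false)?;
            let special = tokenizer.get_added_vocabulary().is_special_token(&piece);
            println!("streamed: {piece:?} (special: {special})");
        }

        // Final decode over the accumulated ids, skipping special tokens.
        let text = tokenizer.decode(&state, true)?;
        println!("final: {text}");
        Ok(())
    }

Decoding ids one at a time can split multi-byte characters on byte-level BPE
vocabularies; the full re-decode at the end is what guarantees a clean final string.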
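Note on the unsafe impl Send above: cxx leaves extern C++ types without an automatic
Send implementation, so moving the bridged executor handle onto the looper threads
requires a manual, unsafe promise of thread-safety. A sketch of that pattern, with a
hypothetical raw-pointer Handle standing in for the cxx-bridged TensorRtLlmBackendImpl:

    struct Handle {
        raw: *mut core::ffi::c_void, // raw pointers make a type !Send by default
    }

    // SAFETY: only sound if the underlying object really may be used from another
    // thread; that is exactly the claim the diff's impl makes for the C++ backend.
    unsafe impl Send for Handle {}

    fn main() {
        let h = Handle { raw: std::ptr::null_mut() };
        std::thread::spawn(move || {
            // The handle can now move across threads.
            assert!(h.raw.is_null());
        })
        .join()
        .unwrap();
    }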