Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
add all the necessary plumbing to return the generated content
Commit 69674a3a2d, parent ce715c76f8.
@@ -15,7 +15,7 @@ use tokio::sync::RwLock;
 use tokio::time::{Instant, sleep};
 use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span};
+use tracing::{instrument, Level, span, Span};

 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -25,42 +25,23 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};

-// macro_rules! propagate {
-//     ($ctx: expr, $res: expr) => {
-//         $ctx.sender
-//             .send($res)
-//             .expect("Failed to propagate error back to the transport layer")
-//     };
-// }
-
 type InferResult<T> = Result<T, InferError>;

-/// Holds the user provided input to be executed along with a channel allowing
-/// to bubble up all the generated tokens for that request to the end stream.
-// pub struct InferenceContext {
-//     /// User provided request
-//     request: ValidGenerateRequest,
-//
-//     /// Inter-process communication handler moving token from the executor thread to the HTTP server
-//     sender: UnboundedSender<InferResult<InferStreamResponse>>,
-//
-//     /// Pin the instant this inference context was submitted
-//     when: Instant,
-//
-//     /// Span that will live as long as entry
-//     span: Span,
-// }
-
 pub(crate) struct Generation {
     executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
     done: Arc<AtomicBool>,
 }

+/// Holds the user provided input to be executed along with a channel allowing
+/// to bubble up all the generated tokens for that request to the end stream.
 #[derive(Clone)]
 pub struct GenerationContext {
     sender: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokenizer: Arc<Tokenizer>,
+    tokens: Vec<u32>,
     done: Arc<AtomicBool>,
+    start: Instant,
     span: Span,
 }

 impl Stream for Generation {
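The `GenerationContext` above only holds the sending half of the channel; the matching receiver is what the router side turns into a token stream, which is why `UnboundedReceiverStream` is imported in the first hunk. Below is a minimal, self-contained sketch of that plumbing, with a hypothetical `Msg` enum standing in for `InferResult<InferStreamResponse>` so it compiles outside the crate.

// Minimal sketch of the sender/receiver plumbing the struct above relies on.
// `Msg` is a stand-in for `InferResult<InferStreamResponse>`; names here are
// illustrative, not the crate's real API.
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[derive(Debug)]
enum Msg {
    Intermediate { token_id: u32, text: String },
    End { full_text: String },
}

#[tokio::main]
async fn main() {
    let (sender, receiver): (UnboundedSender<Msg>, _) = unbounded_channel();

    // The executor-side callback clones `sender` into its context and pushes
    // one message per decoded token, then a final `End` message.
    sender
        .send(Msg::Intermediate { token_id: 42, text: "Hello".into() })
        .expect("receiver dropped");
    sender
        .send(Msg::End { full_text: "Hello world".into() })
        .expect("receiver dropped");
    drop(sender); // closing the channel terminates the stream below

    // The HTTP layer wraps the receiver so the tokens can be consumed as a `Stream`.
    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(msg) = stream.next().await {
        println!("{msg:?}");
    }
}

An unbounded channel is a natural fit here because the producer is a native callback that must never block; the trade-off is that nothing applies backpressure if the HTTP client reads slowly.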
@@ -175,7 +156,10 @@ impl TensorRtLlmBackend {
         let ctx = Box::new(GenerationContext {
             sender: sender.clone(),
+            tokenizer: tokenizer.clone(),
+            tokens: vec![],
             done: Arc::clone(&generation.done),
+            start: Instant::now(),
             span: Span::current(),
         });

         // We are leaking the context on-purpose to avoid the box being dropped while there are
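The comment about leaking the context refers to keeping the boxed `GenerationContext` alive while the C++ executor still holds a raw pointer to it and may invoke the token callback. Here is a plain-Rust sketch of that hand-off pattern, with a hypothetical `fake_executor` in place of the cxx FFI layer and a simplified context struct; whether and where the real backend reclaims the box is not shown in this hunk.

// Sketch of the raw-pointer hand-off hinted at by the comment above: the
// context is turned into a raw pointer so it outlives the Rust scope while a
// (here simulated) foreign executor drives the callback, and is only reclaimed
// and dropped once the final callback has fired. Hypothetical types only.
struct GenerationContext {
    tokens: Vec<u32>,
}

// Stand-in for the C++ side invoking token callbacks with the raw context pointer.
fn fake_executor(ctx: *mut GenerationContext, callback: impl Fn(*mut GenerationContext, u32, bool)) {
    for (i, token_id) in [5u32, 7, 9].iter().enumerate() {
        callback(ctx, *token_id, i == 2);
    }
}

fn main() {
    // Leak the box so the allocation is not freed when this scope ends.
    let ctx: *mut GenerationContext = Box::into_raw(Box::new(GenerationContext { tokens: vec![] }));

    fake_executor(ctx, |ctx, token_id, is_final| {
        // SAFETY: the pointer came from Box::into_raw above and is only used
        // from this single-threaded example, one callback at a time.
        let inner_ctx = unsafe { &mut *ctx };
        inner_ctx.tokens.push(token_id);
        if is_final {
            println!("generated {} tokens", inner_ctx.tokens.len());
        }
    });

    // Reclaim ownership so the context is dropped exactly once.
    drop(unsafe { Box::from_raw(ctx) });
}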
@@ -209,45 +193,50 @@ impl TensorRtLlmBackend {
             request_id,
             ctx_,
             |ctx: *mut GenerationContext,
-             token: u32,
+             token_id: u32,
              logprob: f32,
              is_final: bool| {
-                // let text = ctx
-                //     .tokenizer
-                //     .decode(&[token], true)
-                //     .expect("Failed to decode token");
-                info!("Decoded token: {}", token);
+                let inner_ctx = &mut *ctx;
+                inner_ctx.tokens.push(token_id);
+
+                let text = inner_ctx
+                    .tokenizer
+                    .decode(&[token_id], true)
+                    .expect("Failed to decode token");
+
+                let token = Token {
+                    id: token_id,
+                    text,
+                    logprob,
+                    special: false,
+                };
+
                 let out = if is_final {
-                    (*ctx).done.store(true, Ordering::Relaxed);
+                    inner_ctx.done.store(true, Ordering::Relaxed);
+                    let generated_text = inner_ctx
+                        .tokenizer
+                        .decode(&inner_ctx.tokens, true)
+                        .expect("Failed to decode generated_tokens");
+
                     InferStreamResponse::End {
-                        token: Token {
-                            id: token,
-                            text: "".into(),
-                            logprob,
-                            special: false,
-                        },
+                        token,
                         top_tokens: vec![],
                         generated_text: GeneratedText {
-                            text: "".into(),
-                            generated_tokens: u32::MAX,
+                            text: generated_text,
+                            generated_tokens: inner_ctx.tokens.len() as u32,
                             finish_reason: FinishReason::EndOfSequenceToken,
                             seed: None,
                         },
-                        start: Instant::now(),
+                        start: inner_ctx.start,
                         queued: Instant::now(),
                     }
                 } else {
                     InferStreamResponse::Intermediate {
-                        token: Token {
-                            id: token,
-                            text: "".into(),
-                            logprob,
-                            special: false,
-                        },
+                        token,
                         top_tokens: vec![],
                     }
                 };
-                (*ctx)
+                inner_ctx
                     .sender
                     .send(Ok(out))
                     .expect("Failed to send back generated token");
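The callback above decodes each `token_id` on its own for the streamed `Token.text` and, when `is_final` is set, re-decodes the whole accumulated `tokens` buffer to fill `GeneratedText.text` and `generated_tokens`. A rough sketch of that two-level decode with the `tokenizers` crate follows; it assumes the crate's `http` feature so `from_pretrained` can fetch a tokenizer, whereas the backend receives an already-loaded `Arc<Tokenizer>`.

// Sketch of the decode strategy used in the callback: per-token decode for the
// streamed text, full-sequence decode once the final token arrives.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;

    // Pretend these ids arrived one by one from the executor callback.
    let generated_ids: Vec<u32> = tokenizer.encode("Hello world", false)?.get_ids().to_vec();

    let mut tokens: Vec<u32> = Vec::new();
    for (i, &token_id) in generated_ids.iter().enumerate() {
        tokens.push(token_id);

        // Per-token decode: what would back the streamed `Token { text, .. }`.
        let text = tokenizer.decode(&[token_id], true)?;
        println!("token {token_id} -> {text:?}");

        // Final token: decode the whole sequence, as the `is_final` branch does.
        if i == generated_ids.len() - 1 {
            let generated_text = tokenizer.decode(&tokens, true)?;
            println!("generated_text = {generated_text:?} ({} tokens)", tokens.len());
        }
    }
    Ok(())
}

One caveat this sketch shares with the diff: decoding byte-level or multi-byte tokens one id at a time can yield replacement characters, which is why streaming decoders often buffer ids until they decode cleanly.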