add all the necessary plumbing to return the generated content

Morgan Funtowicz 2024-07-17 22:12:49 +00:00
parent ce715c76f8
commit 69674a3a2d

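At a glance, the "plumbing" is the path a decoded token takes from the C++ executor callback back to the HTTP layer: each callback invocation pushes an InferStreamResponse into an unbounded channel, and the receiving half of that channel is exposed as a stream the router can poll. A minimal sketch of that channel shape, assuming the sender stored in GenerationContext is the tokio unbounded MPSC sender the imports suggest (the make_stream_pair helper is illustrative, not part of the codebase):

use text_generation_router::infer::{InferError, InferStreamResponse};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;

// Illustrative helper: build the channel the backend writes generated tokens into
// and the stream the HTTP layer polls to forward them to the client.
fn make_stream_pair() -> (
    UnboundedSender<Result<InferStreamResponse, InferError>>,
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
) {
    let (sender, receiver) = unbounded_channel();
    (sender, UnboundedReceiverStream::new(receiver))
}

The diff below keeps such a sender inside GenerationContext so the FFI callback can push every decoded token, and finally the full generated text, through it.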

@@ -15,7 +15,7 @@ use tokio::sync::RwLock;
use tokio::time::{Instant, sleep};
use tokio_stream::{Stream, StreamExt};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{instrument, Level, span};
use tracing::{instrument, Level, span, Span};
use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -25,42 +25,23 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
// macro_rules! propagate {
// ($ctx: expr, $res: expr) => {
// $ctx.sender
// .send($res)
// .expect("Failed to propagate error back to the transport layer")
// };
// }
type InferResult<T> = Result<T, InferError>;
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that request to the end stream.
// pub struct InferenceContext {
// /// User provided request
// request: ValidGenerateRequest,
//
// /// Inter-process communication handler moving token from the executor thread to the HTTP server
// sender: UnboundedSender<InferResult<InferStreamResponse>>,
//
// /// Pin the instant this inference context was submitted
// when: Instant,
//
// /// Span that will live as long as entry
// span: Span,
// }
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that request to the end stream.
#[derive(Clone)]
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
start: Instant,
span: Span,
}
impl Stream for Generation {
@@ -175,7 +156,10 @@ impl TensorRtLlmBackend {
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer: tokenizer.clone(),
tokens: vec![],
done: Arc::clone(&generation.done),
start: Instant::now(),
span: Span::current(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
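The truncated comment above is about ownership across the FFI boundary: the Box is turned into a raw pointer so Rust does not drop the context while the C++ executor still calls back into it, and the allocation only goes away once ownership is taken back on the Rust side. A minimal sketch of that leak-and-reclaim pattern with a dummy type (where exactly this file reclaims the pointer is outside this hunk, so the reclaim site below is an assumption):

// Stand-in for GenerationContext; the real type is defined above.
struct DummyCtx {
    tokens: Vec<u32>,
}

fn leak_and_reclaim() {
    // "Leak" on purpose: the raw pointer survives this scope and is handed to C++,
    // which passes it back on every token callback.
    let raw: *mut DummyCtx = Box::into_raw(Box::new(DummyCtx { tokens: Vec::new() }));

    // ... the executor invokes the token callback with `raw` for the whole generation ...

    // Reclaim: once generation is finished, rebuild the Box so Rust frees the context.
    let _ctx: Box<DummyCtx> = unsafe { Box::from_raw(raw) };
}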
@@ -209,45 +193,50 @@ impl TensorRtLlmBackend {
request_id,
ctx_,
|ctx: *mut GenerationContext,
token: u32,
token_id: u32,
logprob: f32,
is_final: bool| {
// let text = ctx
// .tokenizer
// .decode(&[token], true)
// .expect("Failed to decode token");
info!("Decoded token: {}", token);
let inner_ctx = &mut *ctx;
inner_ctx.tokens.push(token_id);
let text = inner_ctx
.tokenizer
.decode(&[token_id], true)
.expect("Failed to decode token");
let token = Token {
id: token_id,
text,
logprob,
special: false,
};
let out = if is_final {
(*ctx).done.store(true, Ordering::Relaxed);
inner_ctx.done.store(true, Ordering::Relaxed);
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
InferStreamResponse::End {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: "".into(),
generated_tokens: u32::MAX,
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: Instant::now(),
start: inner_ctx.start,
queued: Instant::now(),
}
} else {
InferStreamResponse::Intermediate {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
token,
top_tokens: vec![],
}
};
(*ctx)
inner_ctx
.sender
.send(Ok(out))
.expect("Failed to send back generated token");
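For completeness, the items sent here are consumed on the other side of the channel, typically by wrapping the receiver in the UnboundedReceiverStream imported at the top of the file and draining it from the HTTP handler. A sketch of such a consumer, assuming it only needs the decoded text (the loop is illustrative, not this router's actual handler):

use text_generation_router::infer::{InferError, InferStreamResponse};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

// Illustrative consumer draining the stream fed by the callback above.
async fn drain(mut stream: UnboundedReceiverStream<Result<InferStreamResponse, InferError>>) {
    while let Some(item) = stream.next().await {
        match item {
            // Every intermediate message carries one freshly decoded token.
            Ok(InferStreamResponse::Intermediate { token, .. }) => {
                print!("{}", token.text);
            }
            // The final message now also carries the fully decoded generated text.
            Ok(InferStreamResponse::End { generated_text, .. }) => {
                println!("\n--- done: {} tokens ---", generated_text.generated_tokens);
                break;
            }
            // Any other variant or an error ends the stream.
            _ => break,
        }
    }
}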