Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
add all the necessary plumbing to return the generated content
commit 69674a3a2d
parent ce715c76f8
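In short, the diff below makes the per-request callback accumulate every generated token id in the request's GenerationContext and, once the final token arrives, decode the whole buffer into the text carried by InferStreamResponse::End. A minimal, self-contained sketch of that accumulate-then-decode pattern follows; the `decode` stub and the trimmed-down `GenerationContext` are illustrative stand-ins for the real `tokenizers::Tokenizer` and the backend's types, not the actual code:

// Illustrative stand-ins only: `decode` mimics a tokenizer, and
// `GenerationContext` is reduced to the single field the commit adds.
struct GenerationContext {
    tokens: Vec<u32>,
}

// Stand-in for `Tokenizer::decode`: render each id as a placeholder string.
fn decode(ids: &[u32]) -> String {
    ids.iter()
        .map(|id| format!("<{id}>"))
        .collect::<Vec<_>>()
        .join(" ")
}

// Called once per generated token, mirroring the executor callback:
// buffer the id, and on the final token decode the whole sequence.
fn on_token(ctx: &mut GenerationContext, token_id: u32, is_final: bool) -> Option<String> {
    ctx.tokens.push(token_id);
    if is_final {
        Some(decode(&ctx.tokens))
    } else {
        None
    }
}

fn main() {
    let mut ctx = GenerationContext { tokens: vec![] };
    let ids = [101u32, 2023, 102];
    for (i, id) in ids.iter().enumerate() {
        if let Some(text) = on_token(&mut ctx, *id, i == ids.len() - 1) {
            println!("generated content: {text}");
        }
    }
}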
@@ -15,7 +15,7 @@ use tokio::sync::RwLock;
 use tokio::time::{Instant, sleep};
 use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span};
+use tracing::{instrument, Level, span, Span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -25,42 +25,23 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-// macro_rules! propagate {
-//     ($ctx: expr, $res: expr) => {
-//         $ctx.sender
-//             .send($res)
-//             .expect("Failed to propagate error back to the transport layer")
-//     };
-// }
-
 type InferResult<T> = Result<T, InferError>;
 
-/// Holds the user provided input to be executed along with a channel allowing
-/// to bubble up all the generated tokens for that tokens the to end stream.
-// pub struct InferenceContext {
-//     /// User provided request
-//     request: ValidGenerateRequest,
-//
-//     /// Inter-process communication handler moving token from the executor thread to the HTTP server
-//     sender: UnboundedSender<InferResult<InferStreamResponse>>,
-//
-//     /// Pin the instant this inference context was submitted
-//     when: Instant,
-//
-//     /// Span that will live as long as entry
-//     span: Span,
-// }
-
 pub(crate) struct Generation {
     executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
     done: Arc<AtomicBool>,
 }
 
+/// Holds the user provided input to be executed along with a channel allowing
+/// to bubble up all the generated tokens for that tokens the to end stream.
 #[derive(Clone)]
 pub struct GenerationContext {
     sender: UnboundedSender<InferResult<InferStreamResponse>>,
     tokenizer: Arc<Tokenizer>,
+    tokens: Vec<u32>,
     done: Arc<AtomicBool>,
+    start: Instant,
+    span: Span,
 }
 
 impl Stream for Generation {
@@ -175,7 +156,10 @@ impl TensorRtLlmBackend
         let ctx = Box::new(GenerationContext {
             sender: sender.clone(),
             tokenizer: tokenizer.clone(),
+            tokens: vec![],
             done: Arc::clone(&generation.done),
+            start: Instant::now(),
+            span: Span::current(),
         });
 
         // We are leaking the context on-purpose to avoid the box being dropped while there are
@@ -209,45 +193,50 @@ impl TensorRtLlmBackend
             request_id,
             ctx_,
             |ctx: *mut GenerationContext,
-             token: u32,
+             token_id: u32,
              logprob: f32,
              is_final: bool| {
-                // let text = ctx
-                //     .tokenizer
-                //     .decode(&[token], true)
-                //     .expect("Failed to decode token");
-                info!("Decoded token: {}", token);
+                let inner_ctx = &mut *ctx;
+                inner_ctx.tokens.push(token_id);
+
+                let text = inner_ctx
+                    .tokenizer
+                    .decode(&[token_id], true)
+                    .expect("Failed to decode token");
+
+                let token = Token {
+                    id: token_id,
+                    text,
+                    logprob,
+                    special: false,
+                };
+
                 let out = if is_final {
-                    (*ctx).done.store(true, Ordering::Relaxed);
+                    inner_ctx.done.store(true, Ordering::Relaxed);
+                    let generated_text = inner_ctx
+                        .tokenizer
+                        .decode(&inner_ctx.tokens, true)
+                        .expect("Failed to decode generated_tokens");
+
                     InferStreamResponse::End {
-                        token: Token {
-                            id: token,
-                            text: "".into(),
-                            logprob,
-                            special: false,
-                        },
+                        token,
                         top_tokens: vec![],
                         generated_text: GeneratedText {
-                            text: "".into(),
-                            generated_tokens: u32::MAX,
+                            text: generated_text,
+                            generated_tokens: inner_ctx.tokens.len() as u32,
                             finish_reason: FinishReason::EndOfSequenceToken,
                             seed: None,
                         },
-                        start: Instant::now(),
+                        start: inner_ctx.start,
                         queued: Instant::now(),
                     }
                 } else {
                     InferStreamResponse::Intermediate {
-                        token: Token {
-                            id: token,
-                            text: "".into(),
-                            logprob,
-                            special: false,
-                        },
+                        token,
                         top_tokens: vec![],
                     }
                 };
-                (*ctx)
+                inner_ctx
                     .sender
                     .send(Ok(out))
                     .expect("Failed to send back generated token");
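For completeness, a rough sketch of how the other end of this plumbing might consume the results: the callback above sends InferStreamResponse values through an UnboundedSender, and the stream side drains them via an UnboundedReceiverStream until the End variant delivers the generated text. The StreamResponse enum below is a simplified stand-in for the router's InferStreamResponse, and the example assumes the tokio and tokio-stream crates:

use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Simplified stand-in for `InferStreamResponse`.
enum StreamResponse {
    Intermediate { token_id: u32 },
    End { generated_text: String },
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    // Executor/callback side: intermediate tokens, then the decoded final text.
    tx.send(StreamResponse::Intermediate { token_id: 42 }).unwrap();
    tx.send(StreamResponse::End { generated_text: "hello world".into() }).unwrap();
    drop(tx); // closing the sender ends the stream

    // HTTP/stream side: drain responses until `End` arrives.
    let mut stream = UnboundedReceiverStream::new(rx);
    while let Some(msg) = stream.next().await {
        match msg {
            StreamResponse::Intermediate { token_id } => println!("token {token_id}"),
            StreamResponse::End { generated_text } => println!("done: {generated_text}"),
        }
    }
}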