//! Bridge between the router's async `Backend` trait and the C++
//! TensorRT-LLM executor exposed through the `cxx` FFI layer.

use std::cell::RefCell;
use std::path::Path;

use async_trait::async_trait;
use cxx::UniquePtr;
use tokenizers::Tokenizer;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;

use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};

use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};

type InferResult<T> = Result<T, InferError>;

/// Carries the sender half of the streaming channel across the FFI boundary so
/// the generation callback can push tokens back to the waiting client.
pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);

pub struct TrtLLmBackend {
    tokenizer: Tokenizer,
    inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
}

// SAFETY: assumed sound because the C++ `TensorRtLlmBackendImpl` behind the
// `UniquePtr` is expected to synchronize access to its own state; the `RefCell`
// only guards Rust-side aliasing.
unsafe impl Sync for TrtLLmBackend {}
unsafe impl Send for TrtLLmBackend {}

impl TrtLLmBackend {
    pub fn new<P: AsRef<Path>>(
        tokenizer: Tokenizer,
        engine_folder: P,
    ) -> Result<Self, TensorRtLlmBackendError> {
        let engine_folder = engine_folder.as_ref();
        // NOTE: `to_str().unwrap()` panics on a non-UTF-8 path; acceptable for
        // an operator-provided engine folder.
        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");

        Ok(Self {
            tokenizer,
            inner: RefCell::new(inner),
        })
    }

    fn infer_text(
        &self,
        ctx: GenerationContext,
        text: &str,
        params: ValidParameters,
    ) -> InferResult<()> {
        // Keep track of processing time
        let start = Instant::now();

        // Clone the sender up front: the context itself is moved across the FFI
        // boundary below, but the final `End` message still needs a channel.
        let end_sender = ctx.0.clone();

        // Encode the input
        let ctx = Box::new(ctx);
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| InferError::GenerationError(e.to_string()))?;

        // Submit the request to the backend and retrieve the handle to query its status
        let request_id = self
            .inner
            .borrow_mut()
            .as_mut()
            .expect("Failed to retrieve pointer to TRTLLM backend")
            .submit(
                encoding.get_ids(),
                // TODO: hardcoded new-token budget; should come from the
                // request's stopping parameters.
                128,
                params.top_k as i32,
                params.top_p,
                params.temperature,
                params.seed,
            );

        // Stream generated tokens
        // TODO: this loop blocks the async executor; it should be moved onto a
        // dedicated thread (e.g. tokio::task::spawn_blocking).
        let num_generated_tokens = self
            .inner
            .borrow_mut()
            .as_mut()
            .expect("Failed to retrieve pointer to TRTLLM backend")
            .stream(ctx, request_id, |ctx, token, _step, _is_final| {
                let sender = ctx.0;
                // TODO: decode the token id with the tokenizer and surface real
                // logprobs instead of the placeholders below.
                let token = Token {
                    id: token,
                    text: String::from(""),
                    logprob: 1.0f32,
                    special: false,
                };

                sender
                    .send(Ok(InferStreamResponse::Intermediate {
                        token,
                        top_tokens: vec![],
                    }))
                    .unwrap()
            });

        // Notify the end of generation on the cloned sender (the boxed context
        // was consumed by `stream` above).
        let _ = end_sender.send(Ok(InferStreamResponse::End {
            token: Token {
                id: 0,
                text: String::from(""),
                logprob: 1.0f32,
                special: false,
            },
            top_tokens: vec![],
            generated_text: GeneratedText {
                text: String::from(""),
                generated_tokens: num_generated_tokens,
                finish_reason: FinishReason::EndOfSequenceToken,
                seed: Some(params.seed),
            },
            start,
            queued: Instant::now(),
        }));

        Ok(())
    }
}

#[async_trait]
impl Backend for TrtLLmBackend {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
        let (sender, receiver) = mpsc::unbounded_channel();
        let ctx = GenerationContext(sender);

        // Unpack parameters
        let params = request.parameters;

        // Ensure we are running in the right conditions for the input (i.e. a single text chunk)
        let input = match request.inputs.len() {
            0 => Err(InferError::GenerationError("No input provided".into())),
            1 => Ok(request.inputs.first().unwrap()),
            _ => Err(InferError::GenerationError(format!(
                "Unsupported multi-chunk inference ({} chunks)",
                request.inputs.len()
            ))),
        }?;

        // Currently we only handle a single chunk of text
        match input {
            Chunk::Text(text) => {
                self.infer_text(ctx, text.as_str(), params)?;
            }
            // Return an error instead of panicking on unsupported modalities.
            Chunk::Image(_) => {
                return Err(InferError::GenerationError(
                    "Image inputs are not supported by this backend".into(),
                ))
            }
        };

        Ok(UnboundedReceiverStream::new(receiver))
    }

    async fn health(&self, _current_health: bool) -> bool {
        self.inner.borrow_mut().is_ready()
    }
}
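
// A minimal, hedged usage sketch (not part of the original backend): it shows
// how a caller might construct the backend, gated behind `#[ignore]` since it
// needs a real tokenizer file and a compiled TensorRT-LLM engine on disk. The
// `tokenizer.json` and `/path/to/engines` paths below are hypothetical.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a tokenizer file and a compiled TensorRT-LLM engine"]
    fn builds_backend_from_engine_folder() {
        // Hypothetical paths: swap in a real tokenizer and engine directory.
        let tokenizer = Tokenizer::from_file("tokenizer.json").expect("tokenizer should load");
        let backend = TrtLLmBackend::new(tokenizer, "/path/to/engines")
            .expect("backend should initialize");
        // `schedule` can then be driven with a `ValidGenerateRequest` produced
        // by the router's validation layer.
        drop(backend);
    }
}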