2024-07-11 21:24:32 +00:00
|
|
|
use std::cell::RefCell;
|
2024-07-01 13:53:23 +00:00
|
|
|
use std::path::Path;
|
|
|
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use cxx::UniquePtr;
|
2024-07-03 08:27:53 +00:00
|
|
|
use tokenizers::Tokenizer;
|
|
|
|
use tokio::sync::mpsc;
|
2024-07-11 21:24:32 +00:00
|
|
|
use tokio::time::Instant;
|
2024-06-30 21:37:20 +00:00
|
|
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
|
|
|
|
|
|
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
|
2024-07-11 21:24:32 +00:00
|
|
|
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
|
2024-06-30 21:37:20 +00:00
|
|
|
|
2024-07-01 13:53:23 +00:00
|
|
|
use crate::errors::TensorRtLlmBackendError;
|
2024-07-11 21:24:32 +00:00
|
|
|
use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
|
|
|
|
|
|
|
|
/// Context created for a scheduled request: wraps the unbounded channel sender on which
/// `Result<InferStreamResponse, InferError>` items for that request are meant to be sent.
struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
|
2024-07-01 13:53:23 +00:00
|
|
|
|
|
|
|
pub struct TrtLLmBackend {
|
2024-07-03 08:27:53 +00:00
|
|
|
tokenizer: Tokenizer,
|
2024-07-11 21:24:32 +00:00
|
|
|
inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
|
2024-07-01 13:53:23 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// SAFETY (unverified): `TrtLLmBackend` stores its backend handle in a `RefCell`, which is not
// `Sync`. This impl is therefore only sound if the `RefCell` is never borrowed from two threads
// at the same time.
// NOTE(review): nothing in the visible code serialises access — confirm that callers do, or
// replace the `RefCell` with a lock.
unsafe impl Sync for TrtLLmBackend {}
|
|
|
|
// SAFETY (unverified): asserts that a `TrtLLmBackend`, including the FFI object it owns through
// `UniquePtr`, may be moved to another thread.
// NOTE(review): presumably required because the FFI type is not `Send` on its own — confirm
// that the underlying implementation has no thread affinity.
unsafe impl Send for TrtLLmBackend {}
|
|
|
|
|
|
|
|
impl TrtLLmBackend {
|
2024-07-03 08:27:53 +00:00
|
|
|
pub fn new<P: AsRef<Path>>(
|
|
|
|
tokenizer: Tokenizer,
|
|
|
|
engine_folder: P,
|
|
|
|
) -> Result<Self, TensorRtLlmBackendError> {
|
2024-07-01 13:53:23 +00:00
|
|
|
let engine_folder = engine_folder.as_ref();
|
2024-07-11 21:24:32 +00:00
|
|
|
let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
|
2024-06-30 21:37:20 +00:00
|
|
|
|
2024-07-11 21:24:32 +00:00
|
|
|
Ok(Self {
|
|
|
|
tokenizer,
|
|
|
|
inner: RefCell::new(inner),
|
|
|
|
})
|
2024-07-01 13:53:23 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
impl Backend for TrtLLmBackend {
|
2024-06-30 21:37:20 +00:00
|
|
|
fn schedule(
|
|
|
|
&self,
|
2024-07-03 08:27:53 +00:00
|
|
|
request: ValidGenerateRequest,
|
2024-06-30 21:37:20 +00:00
|
|
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
2024-07-03 08:27:53 +00:00
|
|
|
let (sender, receiver) = mpsc::unbounded_channel();
|
2024-07-11 21:24:32 +00:00
|
|
|
let ctx = Box::new(GenerationContext(sender));
|
|
|
|
|
|
|
|
// Unpack parameters
|
|
|
|
let params = request.parameters;
|
|
|
|
|
|
|
|
// Currently we handle single chunk of text
|
|
|
|
if request.inputs.len() == 1 {
|
|
|
|
match request
|
|
|
|
.inputs
|
|
|
|
.first()
|
|
|
|
.expect("Failed to access the first chunk")
|
|
|
|
{
|
|
|
|
Chunk::Text(text) => {
|
|
|
|
let encoding = self
|
|
|
|
.tokenizer
|
|
|
|
.encode(&**text, true)
|
|
|
|
.map_err(|e| InferError::ToolError(e.to_string()))?;
|
|
|
|
|
|
|
|
let _start = Instant::now();
|
|
|
|
let _request_id = self
|
|
|
|
.inner
|
|
|
|
.borrow_mut()
|
|
|
|
.as_mut()
|
|
|
|
.expect("Failed to retrieve pointer to TRTLLM backend")
|
|
|
|
.submit(
|
|
|
|
encoding.get_ids(),
|
|
|
|
128,
|
|
|
|
params.top_k as i32,
|
|
|
|
params.top_p,
|
|
|
|
params.temperature,
|
|
|
|
params.seed,
|
|
|
|
);
|
|
|
|
|
|
|
|
// spawn_blocking(|| {
|
|
|
|
// // Stream generated tokens
|
|
|
|
// let num_generated_tokens = self
|
|
|
|
// .inner
|
|
|
|
// .borrow_mut()
|
|
|
|
// .as_mut()
|
|
|
|
// .expect("Failed to retrieve pointer to TRTLLM backend")
|
|
|
|
// .stream(request_id, ctx, |token, step, is_final| {
|
|
|
|
// // self.tokenizer.decode(&*[token], true).unwrap();
|
|
|
|
// let token = Token {
|
|
|
|
// id: token,
|
|
|
|
// text: String::from(""),
|
|
|
|
// logprob: 1.0f32,
|
|
|
|
// special: false,
|
|
|
|
// };
|
|
|
|
//
|
|
|
|
// sender
|
|
|
|
// .send(Ok(InferStreamResponse::Intermediate {
|
|
|
|
// token,
|
|
|
|
// top_tokens: vec![],
|
|
|
|
// }))
|
|
|
|
// .unwrap()
|
|
|
|
// });
|
|
|
|
//
|
|
|
|
// // Notify the end
|
|
|
|
// Ok(InferStreamResponse::End {
|
|
|
|
// token: Token {
|
|
|
|
// id: 0,
|
|
|
|
// text: String::from(""),
|
|
|
|
// logprob: 1.0f32,
|
|
|
|
// special: false,
|
|
|
|
// },
|
|
|
|
// top_tokens: vec![],
|
|
|
|
// generated_text: GeneratedText {
|
|
|
|
// text: String::from(""),
|
|
|
|
// generated_tokens: num_generated_tokens,
|
|
|
|
// finish_reason: FinishReason::EndOfSequenceToken,
|
|
|
|
// seed: Some(params.seed),
|
|
|
|
// },
|
|
|
|
// start,
|
|
|
|
// queued: Instant::now(),
|
|
|
|
// })
|
|
|
|
// });
|
|
|
|
}
|
|
|
|
Chunk::Image(_) => {}
|
|
|
|
}
|
|
|
|
};
|
2024-07-03 08:27:53 +00:00
|
|
|
|
|
|
|
Ok(UnboundedReceiverStream::new(receiver))
|
2024-06-30 21:37:20 +00:00
|
|
|
}
|
|
|
|
|
2024-07-01 13:53:23 +00:00
|
|
|
async fn health(&self, _current_health: bool) -> bool {
|
2024-07-11 21:24:32 +00:00
|
|
|
self.inner.borrow_mut().is_ready()
|
2024-06-30 21:37:20 +00:00
|
|
|
}
|
|
|
|
}
|