impl the Rust backend, which currently cannot move the actual computation to a background thread

Morgan Funtowicz 2024-07-12 19:26:32 +00:00
parent 518d9a9e0b
commit b291be64a0

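As the commit message says, generation still runs inline on the async task: `schedule`/`infer_text` reach the TensorRT-LLM handle through `&self` and a `RefCell`, while `tokio::task::spawn_blocking` (left commented out in the diff below) needs a `Send + 'static` closure. A minimal sketch of the intended shape, assuming a sendable handle behind an `Arc<Mutex<...>>`; `BackendHandle`, `next_token` and `stream_in_background` are hypothetical stand-ins, not the `ffi::TensorRtLlmBackendImpl` API from this commit:

// Hypothetical sketch only: `BackendHandle` and `next_token` are stand-ins,
// not the ffi::TensorRtLlmBackendImpl surface from this commit.
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use tokio::task;

struct BackendHandle;

impl BackendHandle {
    // Stand-in for a blocking "poll the next generated token" call into TensorRT-LLM.
    fn next_token(&mut self, _request_id: u64) -> Option<u32> {
        None
    }
}

fn stream_in_background(
    backend: Arc<Mutex<BackendHandle>>,
    request_id: u64,
) -> mpsc::UnboundedReceiver<u32> {
    let (sender, receiver) = mpsc::unbounded_channel();
    // spawn_blocking requires a Send + 'static closure, which is why the
    // &self-borrowed RefCell<UniquePtr<...>> in the current code cannot be moved here yet.
    task::spawn_blocking(move || {
        let mut guard = backend.lock().expect("backend mutex poisoned");
        while let Some(token_id) = guard.next_token(request_id) {
            // Forward each token to the async side; stop if the receiver was dropped.
            if sender.send(token_id).is_err() {
                break;
            }
        }
    });
    receiver
}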

@@ -8,13 +8,16 @@ use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
-struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+type InferResult<T> = Result<T, InferError>;
+pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
pub struct TrtLLmBackend {
tokenizer: Tokenizer,
@@ -30,42 +33,32 @@ impl TrtLLmBackend {
engine_folder: P,
) -> Result<Self, TensorRtLlmBackendError> {
let engine_folder = engine_folder.as_ref();
-let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
+let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
Ok(Self {
tokenizer,
inner: RefCell::new(inner),
})
}
}
-#[async_trait]
-impl Backend for TrtLLmBackend {
-fn schedule(
+fn infer_text(
&self,
-request: ValidGenerateRequest,
-) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-let (sender, receiver) = mpsc::unbounded_channel();
-let ctx = Box::new(GenerationContext(sender));
+ctx: GenerationContext,
+text: &str,
+params: ValidParameters,
+) -> InferResult<()> {
// Keep track of processing time
let start = Instant::now();
-// Unpack parameters
-let params = request.parameters;
-// Currently we handle single chunk of text
-if request.inputs.len() == 1 {
-match request
-.inputs
-.first()
-.expect("Failed to access the first chunk")
-{
-Chunk::Text(text) => {
// Encode the input
+let ctx = Box::new(ctx);
let encoding = self
.tokenizer
-.encode(&**text, true)
+.encode(text, true)
.map_err(|e| InferError::ToolError(e.to_string()))?;
-let _start = Instant::now();
-let _request_id = self
+// Submit the request to the backend and retrieve the handle to query its status
+let request_id = self
.inner
.borrow_mut()
.as_mut()
@@ -79,52 +72,83 @@ impl Backend for TrtLLmBackend {
params.seed,
);
-// spawn_blocking(|| {
-// // Stream generated tokens
-// let num_generated_tokens = self
-// .inner
-// .borrow_mut()
-// .as_mut()
-// .expect("Failed to retrieve pointer to TRTLLM backend")
-// .stream(request_id, ctx, |token, step, is_final| {
-// // self.tokenizer.decode(&*[token], true).unwrap();
-// let token = Token {
-// id: token,
-// text: String::from(""),
-// logprob: 1.0f32,
-// special: false,
-// };
-//
-// sender
-// .send(Ok(InferStreamResponse::Intermediate {
-// token,
-// top_tokens: vec![],
-// }))
-// .unwrap()
-// });
-//
-// // Notify the end
-// Ok(InferStreamResponse::End {
-// token: Token {
-// id: 0,
-// text: String::from(""),
-// logprob: 1.0f32,
-// special: false,
-// },
-// top_tokens: vec![],
-// generated_text: GeneratedText {
-// text: String::from(""),
-// generated_tokens: num_generated_tokens,
-// finish_reason: FinishReason::EndOfSequenceToken,
-// seed: Some(params.seed),
-// },
-// start,
-// queued: Instant::now(),
-// })
+// Stream generated tokens
+// spawn_blocking(move || {
+let num_generated_tokens = self
+.inner
+.borrow_mut()
+.as_mut()
+.expect("Failed to retrieve pointer to TRTLLM backend")
+.stream(ctx, request_id, |ctx, token, step, is_final| {
+// self.tokenizer.decode(&*[token], true).unwrap();
+let sender = ctx.0;
+let token = Token {
+id: token,
+text: String::from(""),
+logprob: 1.0f32,
+special: false,
+};
+sender
+.send(Ok(InferStreamResponse::Intermediate {
+token,
+top_tokens: vec![],
+}))
+.unwrap()
+});
+// Notify the end
+let _ = ctx.0.send(Ok(InferStreamResponse::End {
+token: Token {
+id: 0,
+text: String::from(""),
+logprob: 1.0f32,
+special: false,
+},
+top_tokens: vec![],
+generated_text: GeneratedText {
+text: String::from(""),
+generated_tokens: num_generated_tokens,
+finish_reason: FinishReason::EndOfSequenceToken,
+seed: Some(params.seed),
+},
+start,
+queued: Instant::now(),
+}));
+// });
+Ok(())
}
-Chunk::Image(_) => {}
}
+#[async_trait]
+impl Backend for TrtLLmBackend {
+fn schedule(
+&self,
+request: ValidGenerateRequest,
+) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+let (sender, receiver) = mpsc::unbounded_channel();
+let ctx = GenerationContext(sender);
+// Unpack parameters
+let params = request.parameters;
+// Ensure we are running in the right conditions for the input (i.e. single textual chunk)
+let input = match request.inputs.len() {
+0 => Err(InferError::GenerationError("No input provided".into())),
+1 => Ok(request.inputs.first().unwrap()),
+_ => Err(InferError::GenerationError(format!(
+"Unsupported multi-chunks ({}) inference.",
+request.inputs.len()
+))),
+}?;
+// Currently we handle single chunk of text
+match input {
+Chunk::Text(text) => {
+self.infer_text(ctx, &**text, params)?;
+}
+Chunk::Image(_) => panic!("Unsupported"),
+};
Ok(UnboundedReceiverStream::new(receiver))
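
For reference, the stream `schedule` hands back is just the receiving end of the unbounded channel wrapped by `GenerationContext`; the router drains it and forwards the messages. A self-contained sketch of that flow with simplified stand-in types (the real `InferStreamResponse`, `InferError` and token payloads live in `text_generation_router`):

use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Simplified stand-ins for the router types used in the diff above.
#[derive(Debug)]
enum Response {
    Intermediate { token_id: u32 },
    End { generated_tokens: u32 },
}

#[tokio::main]
async fn main() {
    let (sender, receiver) = mpsc::unbounded_channel::<Result<Response, String>>();

    // Backend side: the generation loop pushes one Intermediate message per token
    // and a final End message through the sender held by GenerationContext.
    sender
        .send(Ok(Response::Intermediate { token_id: 42 }))
        .unwrap();
    sender.send(Ok(Response::End { generated_tokens: 1 })).unwrap();
    drop(sender); // closing the channel terminates the stream

    // Router side: schedule() wraps the receiver in a stream that the HTTP layer
    // can forward, e.g. as server-sent events.
    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(message) = stream.next().await {
        match message {
            Ok(Response::Intermediate { token_id }) => println!("token {token_id}"),
            Ok(Response::End { generated_tokens }) => {
                println!("done after {generated_tokens} tokens")
            }
            Err(e) => eprintln!("generation failed: {e}"),
        }
    }
}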