mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
impl the rust backend, which currently cannot move the actual computation to a background thread
This commit is contained in:
parent 518d9a9e0b
commit b291be64a0
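
The diff below rewires the TensorRT-LLM Rust backend around the usual channel-backed streaming pattern: `schedule` creates an unbounded tokio channel, wraps the sender in `GenerationContext` so it can be handed to the generation path, and immediately returns the receiver to the router as a stream. A minimal, self-contained sketch of that pattern follows; the `Token` and `InferResult` aliases here are illustrative stand-ins, not the router's real types.

use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

// Illustrative stand-ins for the router types used in the real backend.
type Token = u32;
type InferResult<T> = Result<T, String>;

fn schedule_sketch() -> UnboundedReceiverStream<InferResult<Token>> {
    // The sender side goes to whatever produces tokens; the receiver side is
    // returned to the caller as a Stream.
    let (sender, receiver) = mpsc::unbounded_channel();

    // A trivial producer: push a few tokens, then drop the sender, which is
    // what ends the stream for the consumer. The real backend sends from the
    // FFI streaming callback instead.
    for id in [1u32, 2, 3] {
        let _ = sender.send(Ok(id));
    }

    UnboundedReceiverStream::new(receiver)
}

In the actual change, `GenerationContext` is just a `pub` newtype over that sender so it can be boxed and carried through the FFI `stream` callback.
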
@@ -8,13 +8,16 @@ use tokio::sync::mpsc;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
 use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+type InferResult<T> = Result<T, InferError>;
+
+pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
 
 pub struct TrtLLmBackend {
     tokenizer: Tokenizer,
@@ -30,42 +33,32 @@ impl TrtLLmBackend {
         engine_folder: P,
     ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
-        let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
+        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
 
         Ok(Self {
             tokenizer,
             inner: RefCell::new(inner),
         })
     }
-}
 
-#[async_trait]
-impl Backend for TrtLLmBackend {
-    fn schedule(
+    fn infer_text(
         &self,
-        request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = Box::new(GenerationContext(sender));
+        ctx: GenerationContext,
+        text: &str,
+        params: ValidParameters,
+    ) -> InferResult<()> {
         // Keep track of processing time
         let start = Instant::now();
 
-        // Unpack parameters
-        let params = request.parameters;
-
-        // Currently we handle single chunk of text
-        if request.inputs.len() == 1 {
-            match request
-                .inputs
-                .first()
-                .expect("Failed to access the first chunk")
-            {
-                Chunk::Text(text) => {
         // Encode the input
+        let ctx = Box::new(ctx);
         let encoding = self
             .tokenizer
-            .encode(&**text, true)
+            .encode(text, true)
             .map_err(|e| InferError::ToolError(e.to_string()))?;
 
-        let _start = Instant::now();
-        let _request_id = self
+        // Submit the request to the backend and retrieve the handle to query its status
+        let request_id = self
             .inner
             .borrow_mut()
             .as_mut()
@@ -79,52 +72,83 @@ impl Backend for TrtLLmBackend {
                 params.seed,
             );
 
-        // spawn_blocking(|| {
-        //     // Stream generated tokens
-        //     let num_generated_tokens = self
-        //         .inner
-        //         .borrow_mut()
-        //         .as_mut()
-        //         .expect("Failed to retrieve pointer to TRTLLM backend")
-        //         .stream(request_id, ctx, |token, step, is_final| {
-        //             // self.tokenizer.decode(&*[token], true).unwrap();
-        //             let token = Token {
-        //                 id: token,
-        //                 text: String::from(""),
-        //                 logprob: 1.0f32,
-        //                 special: false,
-        //             };
-        //
-        //             sender
-        //                 .send(Ok(InferStreamResponse::Intermediate {
-        //                     token,
-        //                     top_tokens: vec![],
-        //                 }))
-        //                 .unwrap()
-        //         });
-        //
-        //     // Notify the end
-        //     Ok(InferStreamResponse::End {
-        //         token: Token {
-        //             id: 0,
-        //             text: String::from(""),
-        //             logprob: 1.0f32,
-        //             special: false,
-        //         },
-        //         top_tokens: vec![],
-        //         generated_text: GeneratedText {
-        //             text: String::from(""),
-        //             generated_tokens: num_generated_tokens,
-        //             finish_reason: FinishReason::EndOfSequenceToken,
-        //             seed: Some(params.seed),
-        //         },
-        //         start,
-        //         queued: Instant::now(),
-        //     })
+        // Stream generated tokens
+        // spawn_blocking(move || {
+        let num_generated_tokens = self
+            .inner
+            .borrow_mut()
+            .as_mut()
+            .expect("Failed to retrieve pointer to TRTLLM backend")
+            .stream(ctx, request_id, |ctx, token, step, is_final| {
+                // self.tokenizer.decode(&*[token], true).unwrap();
+                let sender = ctx.0;
+                let token = Token {
+                    id: token,
+                    text: String::from(""),
+                    logprob: 1.0f32,
+                    special: false,
+                };
+
+                sender
+                    .send(Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    }))
+                    .unwrap()
+            });
+
+        // Notify the end
+        let _ = ctx.0.send(Ok(InferStreamResponse::End {
+            token: Token {
+                id: 0,
+                text: String::from(""),
+                logprob: 1.0f32,
+                special: false,
+            },
+            top_tokens: vec![],
+            generated_text: GeneratedText {
+                text: String::from(""),
+                generated_tokens: num_generated_tokens,
+                finish_reason: FinishReason::EndOfSequenceToken,
+                seed: Some(params.seed),
+            },
+            start,
+            queued: Instant::now(),
+        }));
+        // });
+
+        Ok(())
+    }
-                }
-                Chunk::Image(_) => {}
-            }
-        }
+}
 
+#[async_trait]
+impl Backend for TrtLLmBackend {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        let (sender, receiver) = mpsc::unbounded_channel();
+        let ctx = GenerationContext(sender);
+
+        // Unpack parameters
+        let params = request.parameters;
+
+        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
+        let input = match request.inputs.len() {
+            0 => Err(InferError::GenerationError("No input provided".into())),
+            1 => Ok(request.inputs.first().unwrap()),
+            _ => Err(InferError::GenerationError(format!(
+                "Unsupported multi-chunks ({}) inference.",
+                request.inputs.len()
+            ))),
+        }?;
+
+        // Currently we handle single chunk of text
+        match input {
+            Chunk::Text(text) => {
+                self.infer_text(ctx, &**text, params)?;
+            }
+            Chunk::Image(_) => panic!("Unsupported"),
+        };
+
+        Ok(UnboundedReceiverStream::new(receiver))
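
On the commit title's caveat that the computation cannot yet move to a background thread: the blocking `stream(...)` call runs inline in `infer_text` because the FFI handle lives in a `RefCell` behind `&self` and is presumably not `Send`, so it cannot simply be handed to `tokio::task::spawn_blocking` (hence the commented-out `spawn_blocking` above). One possible direction, purely as a sketch and not what this commit does, is to park the non-Send backend on a dedicated OS thread for its whole lifetime and feed it work over a channel. The `WorkItem` type and `spawn_backend_thread` helper below are hypothetical.

use std::thread;
use tokio::sync::mpsc;

// Hypothetical work item: the async side sends one of these per request and
// receives generated tokens back through the embedded sender.
struct WorkItem {
    input_ids: Vec<u32>,
    tokens: mpsc::UnboundedSender<u32>,
}

// Spawn a dedicated OS thread that owns the (non-Send) backend handle for its
// whole lifetime, so nothing FFI-related ever has to cross a thread boundary.
fn spawn_backend_thread() -> mpsc::UnboundedSender<WorkItem> {
    let (tx, mut rx) = mpsc::unbounded_channel::<WorkItem>();
    thread::spawn(move || {
        // let mut backend = create_tensorrt_llm_backend(engine_folder, ""); // would live here
        while let Some(work) = rx.blocking_recv() {
            // Placeholder for the real blocking generation loop: echo the input
            // ids back as "generated" tokens so the sketch stands on its own.
            for id in work.input_ids {
                let _ = work.tokens.send(id);
            }
        }
    });
    tx
}

With something like this, `schedule` would only have to push a `WorkItem` and return the `UnboundedReceiverStream` it already creates; whether the FFI handle behind the `RefCell` (presumably a cxx `UniquePtr`, given the `.as_mut().expect(...)` pattern) can actually be moved onto such a thread is exactly the open question the commit message points at.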