Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
impl the Rust backend, which currently cannot move the actual computation to a background thread

parent 518d9a9e0b
commit b291be64a0
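Note on the limitation called out in the commit title: the code added below still drives the TensorRT-LLM stream call directly from `schedule`/`infer_text`, with the `spawn_blocking` wrapper left commented out, so the blocking token loop runs on the caller's async executor thread. The following is only a minimal sketch of the intended pattern, not the project's API: `generate_tokens` and `schedule_in_background` are hypothetical stand-ins for the blocking FFI decode loop, and the sketch assumes everything moved into the blocking closure is Send + 'static, which the RefCell-held backend pointer in this commit presumably is not yet.

// Hypothetical sketch only: `generate_tokens` stands in for the blocking
// TensorRT-LLM decode loop and is not part of the real ffi module.
use tokio::sync::mpsc;
use tokio::task;
use tokio_stream::wrappers::UnboundedReceiverStream;

fn generate_tokens() -> Vec<u32> {
    // Pretend this blocks for a while inside the C++ runtime.
    vec![17, 42, 7]
}

fn schedule_in_background() -> UnboundedReceiverStream<u32> {
    let (sender, receiver) = mpsc::unbounded_channel();

    // The closure runs on tokio's dedicated blocking thread pool, so the async
    // executor is never stalled by the decode loop. Everything captured here
    // must be Send + 'static, which is the constraint the current
    // RefCell-held FFI handle does not meet.
    task::spawn_blocking(move || {
        for token in generate_tokens() {
            // The receiver side may already be dropped; ignore the error.
            let _ = sender.send(token);
        }
    });

    UnboundedReceiverStream::new(receiver)
}

#[tokio::main]
async fn main() {
    use tokio_stream::StreamExt;

    // Consume tokens from async code while generation runs on the blocking pool.
    let mut stream = schedule_in_background();
    while let Some(token) = stream.next().await {
        println!("generated token id: {token}");
    }
}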
@@ -8,13 +8,16 @@ use tokio::sync::mpsc;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+type InferResult<T> = Result<T, InferError>;
+
+pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
 
 pub struct TrtLLmBackend {
     tokenizer: Tokenizer,
@@ -30,42 +33,32 @@ impl TrtLLmBackend {
         engine_folder: P,
     ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
-        let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
+        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
 
         Ok(Self {
             tokenizer,
             inner: RefCell::new(inner),
         })
     }
-}
 
-#[async_trait]
-impl Backend for TrtLLmBackend {
-    fn schedule(
+    fn infer_text(
         &self,
-        request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = Box::new(GenerationContext(sender));
+        ctx: GenerationContext,
+        text: &str,
+        params: ValidParameters,
+    ) -> InferResult<()> {
+        // Keep track of processing time
+        let start = Instant::now();
 
-        // Unpack parameters
-        let params = request.parameters;
-
-        // Currently we handle single chunk of text
-        if request.inputs.len() == 1 {
-            match request
-                .inputs
-                .first()
-                .expect("Failed to access the first chunk")
-            {
-                Chunk::Text(text) => {
-                    let encoding = self
-                        .tokenizer
-                        .encode(&**text, true)
-                        .map_err(|e| InferError::ToolError(e.to_string()))?;
+        // Encode the input
+        let ctx = Box::new(ctx);
+        let encoding = self
+            .tokenizer
+            .encode(text, true)
+            .map_err(|e| InferError::ToolError(e.to_string()))?;
 
-                    let _start = Instant::now();
-                    let _request_id = self
-                        .inner
-                        .borrow_mut()
-                        .as_mut()
+        // Submit the request to the backend and retrieve the handle to query its status
+        let request_id = self
+            .inner
+            .borrow_mut()
+            .as_mut()
@@ -79,52 +72,83 @@ impl Backend for TrtLLmBackend {
                 params.seed,
             );
 
-                    // spawn_blocking(|| {
-                    //     // Stream generated tokens
-                    //     let num_generated_tokens = self
-                    //         .inner
-                    //         .borrow_mut()
-                    //         .as_mut()
-                    //         .expect("Failed to retrieve pointer to TRTLLM backend")
-                    //         .stream(request_id, ctx, |token, step, is_final| {
-                    //             // self.tokenizer.decode(&*[token], true).unwrap();
-                    //             let token = Token {
-                    //                 id: token,
-                    //                 text: String::from(""),
-                    //                 logprob: 1.0f32,
-                    //                 special: false,
-                    //             };
-                    //
-                    //             sender
-                    //                 .send(Ok(InferStreamResponse::Intermediate {
-                    //                     token,
-                    //                     top_tokens: vec![],
-                    //                 }))
-                    //                 .unwrap()
-                    //         });
-                    //
-                    //     // Notify the end
-                    //     Ok(InferStreamResponse::End {
-                    //         token: Token {
-                    //             id: 0,
-                    //             text: String::from(""),
-                    //             logprob: 1.0f32,
-                    //             special: false,
-                    //         },
-                    //         top_tokens: vec![],
-                    //         generated_text: GeneratedText {
-                    //             text: String::from(""),
-                    //             generated_tokens: num_generated_tokens,
-                    //             finish_reason: FinishReason::EndOfSequenceToken,
-                    //             seed: Some(params.seed),
-                    //         },
-                    //         start,
-                    //         queued: Instant::now(),
-                    //     })
-                    // });
-                }
-                Chunk::Image(_) => {}
-            }
+        // Stream generated tokens
+        // spawn_blocking(move || {
+        let num_generated_tokens = self
+            .inner
+            .borrow_mut()
+            .as_mut()
+            .expect("Failed to retrieve pointer to TRTLLM backend")
+            .stream(ctx, request_id, |ctx, token, step, is_final| {
+                // self.tokenizer.decode(&*[token], true).unwrap();
+                let sender = ctx.0;
+                let token = Token {
+                    id: token,
+                    text: String::from(""),
+                    logprob: 1.0f32,
+                    special: false,
+                };
+
+                sender
+                    .send(Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    }))
+                    .unwrap()
+            });
+
+        // Notify the end
+        let _ = ctx.0.send(Ok(InferStreamResponse::End {
+            token: Token {
+                id: 0,
+                text: String::from(""),
+                logprob: 1.0f32,
+                special: false,
+            },
+            top_tokens: vec![],
+            generated_text: GeneratedText {
+                text: String::from(""),
+                generated_tokens: num_generated_tokens,
+                finish_reason: FinishReason::EndOfSequenceToken,
+                seed: Some(params.seed),
+            },
+            start,
+            queued: Instant::now(),
+        }));
+        // });
+
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl Backend for TrtLLmBackend {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        let (sender, receiver) = mpsc::unbounded_channel();
+        let ctx = GenerationContext(sender);
+
+        // Unpack parameters
+        let params = request.parameters;
+
+        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
+        let input = match request.inputs.len() {
+            0 => Err(InferError::GenerationError("No input provided".into())),
+            1 => Ok(request.inputs.first().unwrap()),
+            _ => Err(InferError::GenerationError(format!(
+                "Unsupported multi-chunks ({}) inference.",
+                request.inputs.len()
+            ))),
+        }?;
+
+        // Currently we handle single chunk of text
+        match input {
+            Chunk::Text(text) => {
+                self.infer_text(ctx, &**text, params)?;
+            }
+            Chunk::Image(_) => panic!("Unsupported"),
+        };
 
         Ok(UnboundedReceiverStream::new(receiver))
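A rough edge visible above: every streamed `Token` is emitted with `text: String::from("")` because the `self.tokenizer.decode(...)` call is still commented out inside the callback. The following is only a minimal sketch of that decode step using the `tokenizers` crate; the `tokenizer.json` path and the prompt are assumptions, and the exact `decode` signature varies slightly across crate versions.

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any Hugging Face tokenizer.json works here; the path is an assumption.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Encode a prompt the same way `infer_text` does (add_special_tokens = true).
    let encoding = tokenizer.encode("Hello world", true)?;
    let ids = encoding.get_ids();

    // Decoding a generated id back to its surface form is what would
    // eventually populate the empty `text` field of each streamed Token.
    let piece = tokenizer.decode(&ids[..1], true)?;
    println!("first token decodes to: {piece:?}");

    Ok(())
}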