feat(trtllm): do not tokenize twice

Morgan Funtowicz 2024-10-21 15:06:54 +02:00
parent 1a3da05f34
commit 8d1c3c8ad4


@@ -33,15 +33,9 @@ struct IdentifiableRequest<T> {
     inner: T,
 }
 
-/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
-struct ValidGenerateRequestWithTokens {
-    encoding: Encoding,
-    inner: ValidGenerateRequest,
-}
-
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
+    request: ValidGenerateRequest,
     start: Option<Instant>,
     queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
@@ -97,12 +91,13 @@ fn executor_status_looper(
             if let Some(mut ctx) = waiting_requests.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
-                let generation_params = &request.inner.parameters;
-                let stopping_params = &request.inner.stopping_parameters;
+                let generation_params = &request.parameters;
+                let stopping_params = &request.stopping_parameters;
+                let input_ids = request.input_ids.as_deref();
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    request.encoding.get_ids(),
+                    &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
@@ -343,7 +338,11 @@ impl TensorRtLlmBackendV2 {
         })
     }
 
-    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+        if request.input_ids.is_none() {
+            return Err(ValidationError(UnsupportedModality("No token provided")));
+        }
+
         if request.top_n_tokens > 1 {
             return Err(ValidationError(TopNTokensDisabled));
         }
@@ -359,7 +358,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(text),
+                Chunk::Text(_) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
@@ -372,15 +371,7 @@ impl Backend for TensorRtLlmBackendV2 {
         &self,
         inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let prompt = Self::validate(&inner)?;
-
-        // We encode the prompt in every request context/thread
-        let encoding = self
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
-
-        let request = ValidGenerateRequestWithTokens { encoding, inner };
+        Self::validate(&inner)?;
 
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
@@ -388,7 +379,7 @@ impl Backend for TensorRtLlmBackendV2 {
         // Send the context to the executor for scheduling
         let queued = Instant::now();
         match self.executor.send(GenerationContext {
-            request,
+            request: inner,
             start: None,
             queued,
             streamer,
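
For context, a minimal, self-contained sketch of the flow this commit establishes: the prompt is tokenized exactly once upstream (during TGI's request validation), and the TensorRT-LLM backend only checks that input_ids is present before forwarding the ids to the executor. All types and functions below are simplified stand-ins with made-up names, not the real TGI or TensorRT-LLM API.

// Simplified stand-in for ValidGenerateRequest; field names are illustrative only.
struct Request {
    // Token ids produced once by the router's validation/tokenization step.
    input_ids: Option<Vec<u32>>,
    max_new_tokens: u32,
}

#[derive(Debug)]
enum BackendError {
    NoTokensProvided,
}

// Same shape as the new validate(): reject requests that were not tokenized upstream.
fn validate(request: &Request) -> Result<(), BackendError> {
    if request.input_ids.is_none() {
        return Err(BackendError::NoTokensProvided);
    }
    Ok(())
}

// Stand-in for backend.pin_mut().submit(...): consumes the pre-computed ids directly,
// so no tokenizer is involved on the backend side anymore.
fn submit(input_ids: &[u32], max_new_tokens: u32) -> u64 {
    println!(
        "scheduling {} prompt tokens, up to {} new tokens",
        input_ids.len(),
        max_new_tokens
    );
    1 // pretend executor request id
}

fn main() -> Result<(), BackendError> {
    // The prompt was tokenized exactly once before reaching the backend.
    let request = Request {
        input_ids: Some(vec![1, 15043, 3186]),
        max_new_tokens: 16,
    };

    validate(&request)?;
    // Safe to unwrap: validate() guarantees input_ids is Some.
    let input_ids = request.input_ids.as_deref().unwrap();
    let request_id = submit(input_ids, request.max_new_tokens);
    println!("submitted as request {request_id}");
    Ok(())
}

The Option encodes the "tokenized upstream" invariant: validate() turns missing ids into a validation error instead of running a second tokenizer pass inside the backend, which is what the removed ValidGenerateRequestWithTokens wrapper used to do.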