feat(trtllm): do not tokenize twice
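
The router tokenizes the prompt once during validation and ships the resulting ids on `ValidGenerateRequest::input_ids`; the TensorRT-LLM backend then encoded the same prompt a second time before submitting it to the executor. Drop the backend-side `tokenizer.encode` pass and the `ValidGenerateRequestWithTokens` wrapper, forward the pre-computed `input_ids` instead, and turn `validate()` into a pure invariant check that rejects requests carrying no tokens.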

Morgan Funtowicz 2024-10-21 15:06:54 +02:00
parent 1a3da05f34
commit 8d1c3c8ad4


@@ -33,15 +33,9 @@ struct IdentifiableRequest<T> {
     inner: T,
 }
 
-/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
-struct ValidGenerateRequestWithTokens {
-    encoding: Encoding,
-    inner: ValidGenerateRequest,
-}
-
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
+    request: ValidGenerateRequest,
     start: Option<Instant>,
     queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
@@ -97,12 +91,13 @@ fn executor_status_looper(
         if let Some(mut ctx) = waiting_requests.blocking_recv() {
             // Submit all the request to the executor and move the context to the in-flight tracker
             let request = &ctx.request;
-            let generation_params = &request.inner.parameters;
-            let stopping_params = &request.inner.stopping_parameters;
+            let generation_params = &request.parameters;
+            let stopping_params = &request.stopping_parameters;
+            let input_ids = request.input_ids.as_deref();
 
             // Submit to the TensorRT-LLM executor for scheduling
             match backend.pin_mut().submit(
-                request.encoding.get_ids(),
+                &input_ids.unwrap(), // This is checked beforehand in validate()
                 stopping_params.max_new_tokens,
                 generation_params.top_k as i32,
                 generation_params.top_p,
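
Worth noting: `as_deref()` borrows the token ids out of the `Option` without cloning them, so the looper can read the prompt ids in place. The exact field type is not visible in this diff, but the pattern is standard; a tiny, self-contained illustration (names unrelated to the TGI codebase):

    fn main() {
        // Stand-in for ValidGenerateRequest::input_ids (assumed Option<Vec<u32>> here).
        let input_ids: Option<Vec<u32>> = Some(vec![1, 15043, 2]);

        // Option<Vec<u32>> -> Option<&[u32]>: a borrow, not a copy.
        let borrowed: Option<&[u32]> = input_ids.as_deref();
        assert_eq!(borrowed.unwrap(), &[1, 15043, 2]);

        // The original Option still owns the ids afterwards.
        assert!(input_ids.is_some());
    }
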
@@ -343,7 +338,11 @@ impl TensorRtLlmBackendV2 {
         })
     }
 
-    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+        if request.input_ids.is_none() {
+            return Err(ValidationError(UnsupportedModality("No token provided")));
+        }
+
         if request.top_n_tokens > 1 {
             return Err(ValidationError(TopNTokensDisabled));
         }
@@ -359,7 +358,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(text),
+                Chunk::Text(text) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
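
For readers outside this codebase, here is a minimal, runnable sketch of the shape `validate()` takes after this change: it no longer returns a borrowed prompt for re-tokenization, it only enforces invariants and returns `()`. Every name below is a simplified stand-in, not the actual TGI definition (the real code wraps errors in `InferResult`/`InferError` and performs additional checks):

    enum Chunk {
        Text(String),
        Image(Vec<u8>),
    }

    struct ValidGenerateRequest {
        inputs: Vec<Chunk>,
        input_ids: Option<Vec<u32>>, // tokenized once, upstream, by the router
        top_n_tokens: u32,
    }

    #[derive(Debug)]
    enum ValidationError {
        UnsupportedModality(&'static str),
        TopNTokensDisabled,
        MultiChunk,
    }

    fn validate(request: &ValidGenerateRequest) -> Result<(), ValidationError> {
        // The executor unwraps input_ids unconditionally, so its presence
        // must be established here, once, before the request is queued.
        if request.input_ids.is_none() {
            return Err(ValidationError::UnsupportedModality("No token provided"));
        }
        if request.top_n_tokens > 1 {
            return Err(ValidationError::TopNTokensDisabled);
        }
        match request.inputs.as_slice() {
            [Chunk::Text(_)] => Ok(()),
            [Chunk::Image(_)] => Err(ValidationError::UnsupportedModality("image")),
            _ => Err(ValidationError::MultiChunk),
        }
    }

    fn main() {
        let ok = ValidGenerateRequest {
            inputs: vec![Chunk::Text("Hello world".into())],
            input_ids: Some(vec![15043, 1533]),
            top_n_tokens: 0,
        };
        assert!(validate(&ok).is_ok());

        let no_tokens = ValidGenerateRequest { input_ids: None, ..ok };
        assert!(validate(&no_tokens).is_err());
    }
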
@@ -372,15 +371,7 @@ impl Backend for TensorRtLlmBackendV2 {
         &self,
         inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let prompt = Self::validate(&inner)?;
-
-        // We encode the prompt in every request context/thread
-        let encoding = self
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
-
-        let request = ValidGenerateRequestWithTokens { encoding, inner };
+        Self::validate(&inner)?;
 
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
@@ -388,7 +379,7 @@ impl Backend for TensorRtLlmBackendV2 {
         // Send the context to the executor for scheduling
         let queued = Instant::now();
         match self.executor.send(GenerationContext {
-            request,
+            request: inner,
             start: None,
             queued,
             streamer,
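
The `// This is checked beforehand in validate()` comment in the executor loop is the other half of this contract: a `GenerationContext` only reaches the waiting-requests channel after `validate()` has confirmed `input_ids` is `Some`, so the `unwrap()` there cannot panic. A condensed, hypothetical sketch of that invariant (the real `submit` is the C++ executor behind `backend.pin_mut().submit(...)` and takes more parameters):

    /// Hypothetical stand-in for the executor behind `backend.pin_mut().submit(...)`.
    struct Executor {
        next_id: u64,
    }

    impl Executor {
        fn submit(&mut self, tokens: &[u32], max_new_tokens: u32) -> u64 {
            // ...hand the pre-tokenized prompt to the scheduler...
            println!("scheduling {} prompt tokens, up to {} new", tokens.len(), max_new_tokens);
            self.next_id += 1;
            self.next_id
        }
    }

    fn schedule(executor: &mut Executor, input_ids: Option<&[u32]>) -> u64 {
        // Safe by construction: requests without ids were rejected by
        // validate() before they were ever queued for the executor.
        executor.submit(input_ids.expect("checked beforehand in validate()"), 128)
    }

    fn main() {
        let mut executor = Executor { next_id: 0 };
        let ids: Vec<u32> = vec![1, 15043, 2];
        assert_eq!(schedule(&mut executor, Some(&ids)), 1);
    }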