mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-05-06 17:52:07 +00:00
feat(trtllm): do not tokenize twice
This commit is contained in:
parent 1a3da05f34
commit 8d1c3c8ad4
@@ -33,15 +33,9 @@ struct IdentifiableRequest<T> {
     inner: T,
 }
 
-/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
-struct ValidGenerateRequestWithTokens {
-    encoding: Encoding,
-    inner: ValidGenerateRequest,
-}
-
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
+    request: ValidGenerateRequest,
     start: Option<Instant>,
     queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
@@ -97,12 +91,13 @@ fn executor_status_looper(
             if let Some(mut ctx) = waiting_requests.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
-                let generation_params = &request.inner.parameters;
-                let stopping_params = &request.inner.stopping_parameters;
+                let generation_params = &request.parameters;
+                let stopping_params = &request.stopping_parameters;
+                let input_ids = request.input_ids.as_deref();
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    request.encoding.get_ids(),
+                    &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
@@ -343,7 +338,11 @@ impl TensorRtLlmBackendV2 {
         })
     }
 
-    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+        if request.input_ids.is_none() {
+            return Err(ValidationError(UnsupportedModality("No token provided")));
+        }
+
         if request.top_n_tokens > 1 {
             return Err(ValidationError(TopNTokensDisabled));
         }
@@ -359,7 +358,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(text),
+                Chunk::Text(text) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
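For reference, a minimal self-contained sketch of the validation shape after this change, using simplified stand-in types (SketchRequest and the plain String error are placeholders, not the actual TGI ValidGenerateRequest or InferError): validate() now only checks that the upstream-provided input_ids are present and returns Ok(()) instead of handing the raw prompt back for a second tokenization.

// Minimal sketch with placeholder types, not the actual TGI definitions.
#[allow(dead_code)]
enum Chunk {
    Text(String),
    Image(Vec<u8>),
}

struct SketchRequest {
    inputs: Vec<Chunk>,
    // Filled by the upstream validation layer; the backend no longer tokenizes.
    input_ids: Option<Vec<u32>>,
    top_n_tokens: u32,
}

fn validate(request: &SketchRequest) -> Result<(), String> {
    if request.input_ids.is_none() {
        return Err("No token provided".into());
    }
    if request.top_n_tokens > 1 {
        return Err("top_n_tokens is not supported".into());
    }
    match request.inputs.len() {
        0 => Err("empty input".into()),
        2.. => Err("multi-chunk input is not supported".into()),
        1 => match request.inputs.first().expect("Single item-chunk") {
            Chunk::Text(_) => Ok(()),
            Chunk::Image(_) => Err("unsupported modality: image".into()),
        },
    }
}

fn main() {
    let request = SketchRequest {
        inputs: vec![Chunk::Text("Hello world".into())],
        input_ids: Some(vec![101, 102, 103]),
        top_n_tokens: 1,
    };
    assert!(validate(&request).is_ok());
}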
@@ -372,15 +371,7 @@ impl Backend for TensorRtLlmBackendV2 {
         &self,
         inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let prompt = Self::validate(&inner)?;
-
-        // We encode the prompt in every request context/thread
-        let encoding = self
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
-
-        let request = ValidGenerateRequestWithTokens { encoding, inner };
+        Self::validate(&inner)?;
 
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
@@ -388,7 +379,7 @@ impl Backend for TensorRtLlmBackendV2 {
         // Send the context to the executor for scheduling
         let queued = Instant::now();
         match self.executor.send(GenerationContext {
-            request,
+            request: inner,
            start: None,
            queued,
            streamer,
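Taken together, the commit removes the second encode pass: the prompt is tokenized once upstream, validate() only asserts that input_ids is present, and the executor loop reads those ids directly. Below is a rough end-to-end sketch of that flow using placeholder types and a plain std::sync::mpsc channel (ValidRequest, GenerationContext, schedule and executor_loop here are simplified stand-ins, not the real TGI or TensorRT-LLM API).

use std::sync::mpsc;

// Placeholder types: stand-ins for the real TGI / TensorRT-LLM structures.
struct ValidRequest {
    // Tokenized once by the upstream validation layer.
    input_ids: Option<Vec<u32>>,
    max_new_tokens: u32,
}

struct GenerationContext {
    request: ValidRequest,
}

// schedule(): no tokenizer call anymore, just validate and forward.
fn schedule(tx: &mpsc::Sender<GenerationContext>, request: ValidRequest) -> Result<(), String> {
    if request.input_ids.is_none() {
        return Err("No token provided".into());
    }
    tx.send(GenerationContext { request })
        .map_err(|e| e.to_string())
}

// Executor loop: consumes the pre-computed ids directly.
fn executor_loop(rx: mpsc::Receiver<GenerationContext>) {
    while let Ok(ctx) = rx.recv() {
        let input_ids = ctx.request.input_ids.as_deref();
        // Checked beforehand in schedule(), mirroring validate() in the diff.
        let ids = input_ids.unwrap();
        println!(
            "submitting {} prompt tokens, max_new_tokens={}",
            ids.len(),
            ctx.request.max_new_tokens
        );
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let handle = std::thread::spawn(move || executor_loop(rx));

    let request = ValidRequest {
        input_ids: Some(vec![101, 2023, 102]),
        max_new_tokens: 16,
    };
    schedule(&tx, request).expect("scheduling failed");

    drop(tx); // close the channel so the executor loop exits
    handle.join().unwrap();
}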