Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-06 17:52:07 +00:00)
feat(trtllm): do not tokenize twice

commit 8d1c3c8ad4
parent 1a3da05f34
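In short: the backend previously re-encoded the prompt for every request inside schedule(), wrapping the result in a ValidGenerateRequestWithTokens. Since the router's validation pipeline already tokenizes the input, this commit drops the wrapper and the second encode, reads the pre-computed input_ids off the request instead, and extends validate() to reject requests that arrive without tokens.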
@@ -33,15 +33,9 @@ struct IdentifiableRequest<T> {
     inner: T,
 }
 
-/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
-struct ValidGenerateRequestWithTokens {
-    encoding: Encoding,
-    inner: ValidGenerateRequest,
-}
-
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
+    request: ValidGenerateRequest,
     start: Option<Instant>,
     queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
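For orientation, a minimal sketch of the data-structure change above (not TGI's actual code; the concrete type of input_ids is not visible in this diff, so Option<Vec<u32>> is an assumption for illustration):

// Post-commit shape: the context stores the validated request directly.
// Before, it stored a ValidGenerateRequestWithTokens { encoding, inner }
// wrapper that duplicated tokens the request already carried.

struct ValidGenerateRequest {
    input_ids: Option<Vec<u32>>, // produced once by the router (assumed type)
}

struct GenerationContext {
    request: ValidGenerateRequest,
}

fn main() {
    let ctx = GenerationContext {
        request: ValidGenerateRequest {
            input_ids: Some(vec![1, 15043, 3186]), // hypothetical token ids
        },
    };
    // The executor loop can read the tokens without re-encoding:
    let ids: &[u32] = ctx.request.input_ids.as_deref().unwrap();
    assert_eq!(ids.len(), 3);
}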
@@ -97,12 +91,13 @@ fn executor_status_looper(
             if let Some(mut ctx) = waiting_requests.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
-                let generation_params = &request.inner.parameters;
-                let stopping_params = &request.inner.stopping_parameters;
+                let generation_params = &request.parameters;
+                let stopping_params = &request.stopping_parameters;
+                let input_ids = request.input_ids.as_deref();
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    request.encoding.get_ids(),
+                    &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
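Note the unwrap() is deliberate rather than defensive: as the inline comment says, validate() (changed in the next hunk) now rejects any request whose input_ids is None before it can reach this loop, so the Option is always Some here.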
@@ -343,7 +338,11 @@ impl TensorRtLlmBackendV2 {
         })
     }
 
-    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+        if request.input_ids.is_none() {
+            return Err(ValidationError(UnsupportedModality("No token provided")));
+        }
+
         if request.top_n_tokens > 1 {
             return Err(ValidationError(TopNTokensDisabled));
         }
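To make the new contract concrete, here is a compilable, standalone sketch of the two checks above. The enum and struct below are simplified stand-ins for TGI's real router types (ValidationError is actually an InferError variant wrapping a richer error enum); only the logic of the hunk is modeled:

// Simplified stand-ins for TGI's router types (assumption, not the real API).
enum ValidationError {
    UnsupportedModality(&'static str),
    TopNTokensDisabled,
}

struct ValidGenerateRequest {
    input_ids: Option<Vec<u32>>, // filled by the router's validation step
    top_n_tokens: u32,
}

// Mirrors the post-commit validate(): no prompt text is returned anymore,
// the function only vetoes requests the backend cannot serve.
fn validate(request: &ValidGenerateRequest) -> Result<(), ValidationError> {
    if request.input_ids.is_none() {
        return Err(ValidationError::UnsupportedModality("No token provided"));
    }
    if request.top_n_tokens > 1 {
        return Err(ValidationError::TopNTokensDisabled);
    }
    Ok(())
}

fn main() {
    let ok = ValidGenerateRequest { input_ids: Some(vec![1, 2, 3]), top_n_tokens: 0 };
    let missing = ValidGenerateRequest { input_ids: None, top_n_tokens: 0 };
    assert!(validate(&ok).is_ok());
    assert!(matches!(
        validate(&missing),
        Err(ValidationError::UnsupportedModality(_))
    ));
}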
@@ -359,7 +358,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(text),
+                Chunk::Text(text) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
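This hunk follows from the signature change above (InferResult<&String> to InferResult<()>): with no re-encoding left in schedule(), the prompt text has no remaining consumer, so the text chunk is only inspected for modality and Ok(text) becomes Ok(()). (The text binding is now unused; Chunk::Text(_) would silence the compiler warning.)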
@@ -372,15 +371,7 @@ impl Backend for TensorRtLlmBackendV2 {
         &self,
         inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let prompt = Self::validate(&inner)?;
+        Self::validate(&inner)?;
 
-        // We encode the prompt in every request context/thread
-        let encoding = self
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
-
-        let request = ValidGenerateRequestWithTokens { encoding, inner };
-
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
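With that, the second tokenization is gone from the hot path: schedule() now only validates and forwards. The streaming plumbing it keeps is the usual tokio pattern of an unbounded channel wrapped into a stream. A minimal standalone sketch of that pattern, using plain String/Result stand-ins for TGI's InferStreamResponse/InferResult types and assuming the tokio and tokio-stream crates:

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Sender side goes to the executor loop; receiver side is handed back
    // to the HTTP layer as a stream of decoded tokens.
    let (streamer, receiver) = unbounded_channel::<Result<String, String>>();
    let stream = UnboundedReceiverStream::new(receiver);

    // Executor side: push tokens as they are decoded.
    streamer.send(Ok("hello".into())).unwrap();
    drop(streamer); // closing the sender terminates the stream

    // Client side: consume until the sender is dropped.
    let items: Vec<_> = stream.collect().await;
    assert_eq!(items.len(), 1);
}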
@@ -388,7 +379,7 @@ impl Backend for TensorRtLlmBackendV2 {
         // Send the context to the executor for scheduling
         let queued = Instant::now();
         match self.executor.send(GenerationContext {
-            request,
+            request: inner,
             start: None,
             queued,
             streamer,
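This last hunk is the mechanical follow-through: with no wrapper left to build, the GenerationContext is constructed from inner directly. Net effect across the diff: tokenization happens exactly once, in the router, instead of once there and again in every backend request context/thread as the removed comment described.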