From 4f7f617e911c77540f15d6ec2f30c1fa0e57e549 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 23 Jan 2024 14:49:04 +0100
Subject: [PATCH] Adding tokenizer route.

---
 router/src/infer.rs      | 22 +++++++++++++++++
 router/src/lib.rs        | 12 +++++++++
 router/src/server.rs     | 53 +++++++++++++++++++++++++++++++++++++---
 router/src/validation.rs | 48 ++++++++++++++++++--------------
 4 files changed, 113 insertions(+), 22 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 8a9875eb..5f078ba0 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -165,6 +165,28 @@ impl Infer {
         ))
     }
 
+    /// Tokenize the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 983079d6..4a0c552d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -432,6 +432,18 @@ pub struct Token {
     special: bool,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct SimpleToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(example = 0)]
+    start: usize,
+    #[schema(example = 2)]
+    stop: usize,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 pub(crate) enum FinishReason {
diff --git a/router/src/server.rs b/router/src/server.rs
index cf1d94a6..365de1ca 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -5,8 +5,8 @@ use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
-    Token, Validation,
+    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -528,7 +528,7 @@ async fn generate_stream_internal(
 /// Generate tokens
 #[utoipa::path(
     post,
-    tag = "Text Generation Inference",
+    tag = "Chat completions",
     path = "/v1/chat/completions",
     request_body = ChatRequest,
     responses(
@@ -672,6 +672,52 @@ async fn chat_completions(
     }
 }
 
+/// Tokenize inputs
+#[utoipa::path(
+    post,
+    tag = "Tokenize",
+    path = "/tokenize",
+    request_body = TokenizeRequest,
+    responses(
+    (status = 200, description = "Tokenized ids", body = TokenizeResponse),
+    (status = 404, description = "No tokenizer found", body = ErrorResponse,
+    example = json ! ({"error": "No fast tokenizer available"})),
({"error": "No fast tokenizer available"})), + ) + )] +#[instrument(skip_all)] +async fn tokenize( + Extension(infer): Extension, + Json(req): Json, +) -> Result)> { + let input = req.inputs.clone(); + let encoding = infer.tokenize(req).await?; + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, (start, stop))| { + let text: String = input.chars().skip(*start).take(stop - start).collect(); + SimpleToken { + id, + text, + start: *start, + stop: *stop, + } + }) + .collect(); + Ok(Json(tokens).into_response()) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } +} + /// Prometheus metrics scrape endpoint #[utoipa::path( get, @@ -867,6 +913,7 @@ pub async fn run( .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) + .route("/tokenize", post(tokenize)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); diff --git a/router/src/validation.rs b/router/src/validation.rs index 370e9588..e6dbcf81 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -70,12 +70,11 @@ impl Validation { } #[instrument(skip(self, inputs))] - async fn validate_input( + pub async fn tokenize( &self, inputs: String, truncate: Option, - max_new_tokens: Option, - ) -> Result<(String, usize, u32), ValidationError> { + ) -> Result, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -88,7 +87,24 @@ impl Validation { // Await on response channel // Unwrap is safe here - let (inputs, input_length) = response_receiver.await.unwrap()?; + let encoding = response_receiver.await.unwrap()?; + Ok(Some(encoding)) + } else { + Ok(None) + } + } + + #[instrument(skip(self, inputs))] + async fn validate_input( + &self, + inputs: String, + truncate: Option, + max_new_tokens: Option, + ) -> Result<(String, usize, u32), ValidationError> { + // If we have a fast tokenizer + if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? 
+            // Get the number of tokens in the encoding
+            let input_length = encoding.len();
 
             // Get total tokens
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@@ -346,33 +362,27 @@ fn prepare_input(
     inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
         .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
-                .decode(encoding.get_ids(), false)
-                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
         }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
-
-    Ok((inputs, input_length))
+    }
+    let inputs = tokenizer
+        .decode(encoding.get_ids(), false)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+    Ok((encoding, inputs))
 }
 
 type TokenizerRequest = (
     (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
 );
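Not part of the patch itself: below is a minimal client-side sketch of how the new /tokenize route could be exercised once this change is applied. It assumes a router listening on http://localhost:3000 (an illustration, not something the patch defines) and uses the reqwest (with the "json" feature), serde (with "derive"), serde_json, and tokio crates purely for demonstration. The request body mirrors the /generate payload, since the handler deserializes a GenerateRequest, and the response is the list of SimpleToken objects added in router/src/lib.rs.

// Hypothetical client sketch, not part of this patch.
// Assumed dependencies: reqwest = { features = ["json"] }, serde = { features = ["derive"] },
// serde_json, tokio = { features = ["full"] }; assumed router address: http://localhost:3000.
use serde::Deserialize;

// Mirrors the SimpleToken schema added in router/src/lib.rs.
#[derive(Debug, Deserialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();

    // POST the same payload shape as /generate: {"inputs": "...", "parameters": {...}},
    // here with parameters omitted and left to the router's defaults.
    let tokens: Vec<SimpleToken> = client
        .post("http://localhost:3000/tokenize")
        .json(&serde_json::json!({ "inputs": "What is Deep Learning?" }))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    // Each token carries its id, its text, and the offsets into the original input
    // as computed by the route.
    for t in &tokens {
        println!("{:>6}  {:?}  [{}..{}]", t.id, t.text, t.start, t.stop);
    }
    Ok(())
}

A plain curl POST with the same JSON body should behave equivalently; per the handler above, the router answers 404 when no fast tokenizer (tokenizer.json) is available for the model.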