From 26b954dfd30f5a54a74835a1e6036dc47723e077 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 30 Jul 2024 13:48:13 +0000
Subject: [PATCH] feat: improve to tokenize too

---
 router/src/lib.rs    |  6 ++++
 router/src/server.rs | 76 ++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 14bb8270..386b0556 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize, ToSchema)]
+pub(crate) struct ChatTokenizeResponse {
+    pub(crate) tokenize_response: TokenizeResponse,
+    pub(crate) templated_text: String,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(transparent)]
 pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
diff --git a/router/src/server.rs b/router/src/server.rs
index d718513f..8088beb2 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,6 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
+use crate::ChatTokenizeResponse;
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -118,22 +119,28 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
-    path = "/templated",
+    path = "/chat_tokenize",
     request_body = ChatRequest,
     responses((status = 200, description = "Templated Chat Request", body = Value))
 )]
-async fn get_templated(
+async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
     metrics::counter!("tgi_request_count").increment(1);
 
     let ChatRequest {
+        model,
+        max_tokens,
         messages,
-        response_format,
+        seed,
+        stop,
+        stream,
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
+        response_format,
         ..
     } = req;
 
@@ -193,7 +200,64 @@
         }
     };
 
-    Ok((HeaderMap::new(), Json(inputs)).into_response())
+    let generate_request = GenerateRequest {
+        inputs,
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature,
+            repetition_penalty: None,
+            frequency_penalty: None,
+            top_k: None,
+            top_p: None,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens: max_tokens,
+            return_full_text: None,
+            stop: stop.unwrap_or_default(),
+            truncate: None,
+            watermark: false,
+            details: false,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: _grammar,
+            adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+        },
+    };
+
+    let input = generate_request.inputs.clone();
+    let encoding = infer.tokenize(generate_request).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text: String =
+                    String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect();
+
+        let resp = ChatTokenizeResponse {
+            tokenize_response: TokenizeResponse(tokens),
+            templated_text: input,
+        };
+        Ok((HeaderMap::new(), Json(resp)))
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
 }
 
 #[utoipa::path(
@@ -2117,7 +2181,7 @@ async fn start(
     }
     let info_routes = Router::new()
         .route("/", get(health))
-        .route("/templated", post(get_templated))
+        .route("/chat_tokenize", post(get_chat_tokenize))
        .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))
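
For reference, below is a minimal client sketch (not part of the patch) of how the new /chat_tokenize route might be exercised once the router is running. Everything outside the patch itself is an assumption for illustration: a router listening on the default localhost:3000, and the reqwest (with its json feature), serde_json, and tokio crates on the client side. The request body mirrors the ChatRequest fields the handler destructures above (model, messages, stream, max_tokens).

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Only the fields the handler actually reads need to be supplied;
    // the rest of ChatRequest is defaulted on the server side.
    let body = json!({
        "model": "tgi",
        "messages": [{ "role": "user", "content": "What is deep learning?" }],
        "stream": false,
        "max_tokens": 20
    });

    // POST to the route registered by this patch.
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/chat_tokenize") // assumed host/port
        .json(&body)
        .send()
        .await?;

    // Per ChatTokenizeResponse, the body should look like:
    // { "tokenize_response": [ { "id": ..., "text": ..., "start": ..., "stop": ... }, ... ],
    //   "templated_text": "..." }
    println!("{}", resp.text().await?);
    Ok(())
}

Note that because TokenizeResponse is #[serde(transparent)], the token list serializes as a bare JSON array rather than a wrapped object.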