feat: improve to tokenize too

drbh 2024-07-30 13:48:13 +00:00
parent 62d7be3727
commit 26b954dfd3
2 changed files with 76 additions and 6 deletions

@@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize, ToSchema)]
+pub(crate) struct ChatTokenizeResponse {
+    pub(crate) tokenize_response: TokenizeResponse,
+    pub(crate) templated_text: String,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(transparent)]
 pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
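Note: because TokenizeResponse is #[serde(transparent)], it serializes as a bare token array rather than a wrapped object. Below is a minimal standalone sketch of the resulting payload shape; the structs are re-stated locally so the example compiles on its own, and the SimpleToken field types are assumed from the handler further down.

use serde::Serialize;

#[derive(Serialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

#[derive(Serialize)]
#[serde(transparent)]
struct TokenizeResponse(Vec<SimpleToken>);

#[derive(Serialize)]
struct ChatTokenizeResponse {
    tokenize_response: TokenizeResponse,
    templated_text: String,
}

fn main() {
    let resp = ChatTokenizeResponse {
        tokenize_response: TokenizeResponse(vec![SimpleToken {
            id: 1,
            text: "<s>".to_string(),
            start: 0,
            stop: 3,
        }]),
        templated_text: "<s>".to_string(),
    };
    // Prints: {"tokenize_response":[{"id":1,"text":"<s>","start":0,"stop":3}],"templated_text":"<s>"}
    println!("{}", serde_json::to_string(&resp).unwrap());
}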

@@ -8,6 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
+use crate::ChatTokenizeResponse;
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -118,22 +119,28 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
-    path = "/templated",
+    path = "/chat_tokenize",
     request_body = ChatRequest,
     responses((status = 200, description = "Templated Chat Request", body = Value))
 )]
-async fn get_templated(
+async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
     metrics::counter!("tgi_request_count").increment(1);
     let ChatRequest {
         model,
+        max_tokens,
         messages,
-        response_format,
+        seed,
+        stop,
+        stream,
         tools,
         tool_choice,
+        tool_prompt,
+        temperature,
+        response_format,
         ..
     } = req;
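For reference, a request body exercising the newly destructured fields might look like the sketch below; every value is an illustrative example, the field names follow the destructure above, and the exact ChatRequest wire format is otherwise assumed.

fn main() {
    // Hypothetical /chat_tokenize request payload; values are arbitrary
    // examples, not defaults.
    let req = serde_json::json!({
        "model": "tgi",
        "messages": [{ "role": "user", "content": "Hello!" }],
        "max_tokens": 32,
        "seed": 42,
        "stop": ["\n\n"],
        "stream": false,
        "temperature": 0.7
    });
    println!("{req:#}");
}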
@@ -193,7 +200,64 @@
         }
     };
-    Ok((HeaderMap::new(), Json(inputs)).into_response())
+    let generate_request = GenerateRequest {
+        inputs,
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature,
+            repetition_penalty: None,
+            frequency_penalty: None,
+            top_k: None,
+            top_p: None,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens: max_tokens,
+            return_full_text: None,
+            stop: stop.unwrap_or_default(),
+            truncate: None,
+            watermark: false,
+            details: false,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: _grammar,
+            adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+        },
+    };
+
+    let input = generate_request.inputs.clone();
+    let encoding = infer.tokenize(generate_request).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text: String =
+                    String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect();
+        let resp = ChatTokenizeResponse {
+            tokenize_response: TokenizeResponse(tokens),
+            templated_text: input,
+        };
+        Ok((HeaderMap::new(), Json(resp)))
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
 }
 
 #[utoipa::path(
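The handler above maps tokenizer byte offsets back into the templated input. A standalone sketch of that mapping, with hand-rolled (id, (start, stop)) pairs standing in for a tokenizers::Encoding; the token ids are made up.

fn main() {
    let input = "Hello world";
    // Stand-in for encoding.get_ids().iter().zip(encoding.get_offsets()).
    let pairs: [(u32, (usize, usize)); 2] = [(9906, (0, 5)), (1917, (5, 11))];
    for (id, (start, stop)) in pairs {
        // Offsets are byte positions into the input, so slice the byte
        // buffer and decode lossily, as the handler does.
        let text = String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
        println!("{id}: {text:?} ({start}..{stop})");
    }
}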
@@ -2117,7 +2181,7 @@ async fn start(
     }
 
     let info_routes = Router::new()
         .route("/", get(health))
-        .route("/templated", post(get_templated))
+        .route("/chat_tokenize", post(get_chat_tokenize))
         .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))