mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: improve to tokenize too
This commit is contained in:
parent
62d7be3727
commit
26b954dfd3
@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse {
|
||||
pub details: Option<Details>,
|
||||
}
|
||||
|
||||
/// Response body produced by the `get_chat_tokenize` handler (`/chat_tokenize` route):
/// the tokenization result together with the text that was tokenized.
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct ChatTokenizeResponse {
|
||||
/// Tokens built from the encoding's ids and offsets for the request input.
pub(crate) tokenize_response: TokenizeResponse,
|
||||
/// The input string that was tokenized; token offsets index into this text.
/// NOTE(review): presumably the chat messages after the chat template is applied — confirm.
pub(crate) templated_text: String,
|
||||
}
|
||||
|
||||
/// Newtype around the list of tokens returned by tokenization.
///
/// Marked `#[serde(transparent)]`, so it serializes exactly as the inner
/// `Vec<SimpleToken>` rather than as a wrapper object.
#[derive(Serialize, ToSchema)]
|
||||
#[serde(transparent)]
|
||||
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
||||
|
@ -8,6 +8,7 @@ use crate::kserve::{
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -118,22 +119,28 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/templated",
|
||||
path = "/chat_tokenize",
|
||||
request_body = ChatRequest,
|
||||
responses((status = 200, description = "Templated Chat Request", body = Value))
|
||||
)]
|
||||
async fn get_templated(
|
||||
async fn get_chat_tokenize(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
response_format,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
..
|
||||
} = req;
|
||||
|
||||
@ -193,7 +200,64 @@ async fn get_templated(
|
||||
}
|
||||
};
|
||||
|
||||
Ok((HeaderMap::new(), Json(inputs)).into_response())
|
||||
let generate_request = GenerateRequest {
|
||||
inputs,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty: None,
|
||||
frequency_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens: max_tokens,
|
||||
return_full_text: None,
|
||||
stop: stop.unwrap_or_default(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: _grammar,
|
||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
};
|
||||
|
||||
let input = generate_request.inputs.clone();
|
||||
let encoding = infer.tokenize(generate_request).await?;
|
||||
if let Some(encoding) = encoding {
|
||||
let tokens: Vec<SimpleToken> = encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text: String =
|
||||
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resp = ChatTokenizeResponse {
|
||||
tokenize_response: TokenizeResponse(tokens),
|
||||
templated_text: input,
|
||||
};
|
||||
Ok((HeaderMap::new(), Json(resp)))
|
||||
} else {
|
||||
Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ErrorResponse {
|
||||
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
|
||||
error_type: "no fast tokenizer".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
@ -2117,7 +2181,7 @@ async fn start(
|
||||
}
|
||||
let info_routes = Router::new()
|
||||
.route("/", get(health))
|
||||
.route("/templated", post(get_templated))
|
||||
.route("/chat_tokenize", post(get_chat_tokenize))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/health", get(health))
|
||||
.route("/ping", get(health))
|
||||
|
Loading…
Reference in New Issue
Block a user