mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Adding tokenizer route.

This commit is contained in:
parent 98e5faff9d
commit 4f7f617e91
@@ -165,6 +165,28 @@ impl Infer {
         ))
     }
 
+    /// Tokenize the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
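
Note: for context on what the `tokenizers::Encoding` returned above carries, here is a minimal standalone sketch using the `tokenizers` crate directly; the tokenizer.json path and input string are placeholders, not part of the commit.

// Sketch only: inspect an Encoding like the one Infer::tokenize surfaces.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumes a tokenizer.json file on disk (placeholder path).
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let encoding = tokenizer.encode("Hello world", true)?;
    println!("ids:     {:?}", encoding.get_ids());
    println!("tokens:  {:?}", encoding.get_tokens());
    println!("offsets: {:?}", encoding.get_offsets());
    Ok(())
}
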
@@ -432,6 +432,18 @@ pub struct Token {
     special: bool,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct SimpleToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(example = 0)]
+    start: usize,
+    #[schema(example = 2)]
+    stop: usize,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 pub(crate) enum FinishReason {
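
Note: a quick sketch of the JSON shape a SimpleToken serializes to. The struct is redeclared locally because the real one is private to the router crate; assumes the serde and serde_json crates.

// Standalone sketch mirroring SimpleToken's serialized form.
use serde::Serialize;

#[derive(Serialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

fn main() {
    let token = SimpleToken { id: 0, text: "test".into(), start: 0, stop: 2 };
    // Prints: {"id":0,"text":"test","start":0,"stop":2}
    println!("{}", serde_json::to_string(&token).unwrap());
}
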
@@ -5,8 +5,8 @@ use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
-    Token, Validation,
+    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -528,7 +528,7 @@ async fn generate_stream_internal(
 /// Generate tokens
 #[utoipa::path(
     post,
-    tag = "Text Generation Inference",
+    tag = "Chat completions",
     path = "/v1/chat/completions",
     request_body = ChatRequest,
     responses(
@@ -672,6 +672,52 @@ async fn chat_completions(
     }
 }
 
+/// Tokenize inputs
+#[utoipa::path(
+    post,
+    tag = "Tokenize",
+    path = "/tokenize",
+    request_body = TokenizeRequest,
+    responses(
+        (status = 200, description = "Tokenized ids", body = TokenizeResponse),
+        (status = 404, description = "No tokenizer found", body = ErrorResponse,
+            example = json ! ({"error": "No fast tokenizer available"})),
+    )
+)]
+#[instrument(skip_all)]
+async fn tokenize(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let input = req.inputs.clone();
+    let encoding = infer.tokenize(req).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, (start, stop))| {
+                let text: String = input.chars().skip(*start).take(stop - start).collect();
+                SimpleToken {
+                    id,
+                    text,
+                    start: *start,
+                    stop: *stop,
+                }
+            })
+            .collect();
+        Ok(Json(tokens).into_response())
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
+}
+
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
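
Note: the core of the handler is the id/offset zip above. A self-contained sketch of the same mapping outside axum, assuming the `tokenizers` crate and a placeholder tokenizer.json; the handler slices with `input.chars()`, so this sketch does the same.

// Sketch of the offset-to-substring mapping used by the /tokenize handler.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let input = "test input";
    let encoding = tokenizer.encode(input, true)?;
    for (&id, (start, stop)) in encoding.get_ids().iter().zip(encoding.get_offsets()) {
        // Slice by chars to mirror the handler's `input.chars().skip(...).take(...)`.
        let text: String = input.chars().skip(*start).take(stop - start).collect();
        println!("{id:>6}  [{start:>3}, {stop:>3})  {text:?}");
    }
    Ok(())
}
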
@@ -867,6 +913,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics));
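
Note: once the route is mounted, it can be exercised over HTTP. A hypothetical client-side smoke test, assuming a router listening on localhost:3000 and the reqwest (with its json feature), tokio, serde, and serde_json crates; none of this is part of the commit, and the request relies on GenerateRequest defaulting its parameters.

// Hypothetical smoke test for POST /tokenize; field names match SimpleToken.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let tokens: Vec<SimpleToken> = reqwest::Client::new()
        .post("http://localhost:3000/tokenize")
        .json(&serde_json::json!({ "inputs": "Hello world" }))
        .send()
        .await?
        .json()
        .await?;
    println!("{tokens:#?}");
    Ok(())
}
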
@@ -70,12 +70,11 @@ impl Validation {
     }
 
     #[instrument(skip(self, inputs))]
-    async fn validate_input(
+    pub async fn tokenize(
         &self,
         inputs: String,
         truncate: Option<usize>,
-        max_new_tokens: Option<u32>,
-    ) -> Result<(String, usize, u32), ValidationError> {
+    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -88,7 +87,24 @@ impl Validation {
 
             // Await on response channel
             // Unwrap is safe here
-            let (inputs, input_length) = response_receiver.await.unwrap()?;
+            let encoding = response_receiver.await.unwrap()?;
+            Ok(Some(encoding))
+        } else {
+            Ok(None)
+        }
+    }
+
+    #[instrument(skip(self, inputs))]
+    async fn validate_input(
+        &self,
+        inputs: String,
+        truncate: Option<usize>,
+        max_new_tokens: Option<u32>,
+    ) -> Result<(String, usize, u32), ValidationError> {
+        // If we have a fast tokenizer
+        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+            // Create response channel
+            let input_length = encoding.len();
 
             // Get total tokens
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
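
Note: a toy, self-contained model of the refactored control flow: the tokenize step now returns Option (None when no fast tokenizer is loaded) and validate_input branches on it. Every name and type below is a stand-in, not the router's real code.

// Stand-in "tokenize": one fake id per word, with optional left truncation.
fn tokenize_stub(inputs: &str, truncate: Option<usize>, have_tokenizer: bool) -> Option<(Vec<u32>, String)> {
    if !have_tokenizer {
        return None;
    }
    let mut ids: Vec<u32> = (0..inputs.split_whitespace().count() as u32).collect();
    if let Some(t) = truncate {
        if t < ids.len() {
            ids.drain(..ids.len() - t); // keep the rightmost t ids
        }
    }
    Some((ids, inputs.to_string()))
}

fn main() {
    // With a tokenizer: validation can derive input_length from the encoding.
    match tokenize_stub("one two three four", Some(2), true) {
        Some((ids, _inputs)) => println!("input_length = {}", ids.len()),
        None => println!("no fast tokenizer; fall back to other heuristics"),
    }
    // Without one: callers get None instead of an error.
    assert!(tokenize_stub("one two", None, false).is_none());
}
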
@@ -346,33 +362,27 @@ fn prepare_input(
     inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
         .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
-                .decode(encoding.get_ids(), false)
-                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
         }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
-
-    Ok((inputs, input_length))
+    }
+    let inputs = tokenizer
+        .decode(encoding.get_ids(), false)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+    Ok((encoding, inputs))
 }
 
 type TokenizerRequest = (
     (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
 );

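
Note: the TokenizerRequest tuple above is the message format for the background tokenizer task: a payload plus a oneshot reply channel. A stripped-down sketch of that round trip, replying with a fake token count instead of a real Encoding; names are illustrative and assume the tokio crate.

// Minimal mpsc + oneshot round trip, modeled on TokenizerRequest.
use tokio::sync::{mpsc, oneshot};

type Request = ((String, Option<usize>), oneshot::Sender<usize>);

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<Request>();

    // Background "tokenizer" task: replies with a fake token count.
    tokio::spawn(async move {
        while let Some(((inputs, _truncate), reply)) = rx.recv().await {
            let _ = reply.send(inputs.split_whitespace().count());
        }
    });

    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send((("hello world".to_string(), None), reply_tx)).unwrap();
    println!("input_length = {}", reply_rx.await.unwrap());
}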