Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00

Adding tokenizer route.

This commit is contained in:
parent 98e5faff9d
commit 4f7f617e91
router/src/infer.rs
@@ -165,6 +165,28 @@ impl Infer {
         ))
     }
 
+    /// Tokenize the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
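For orientation: the Encoding returned here is the tokenizers crate's standard type, and the route added further down only needs its id and offset accessors. A minimal standalone sketch, not part of this commit; "gpt2" is an arbitrary example identifier, and Tokenizer::from_pretrained assumes the crate's http feature plus network access:

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Any fast tokenizer works; "gpt2" is just an example identifier.
        let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;
        let encoding = tokenizer.encode("My name is Olivier", true)?;
        println!("{:?}", encoding.get_ids());     // token ids
        println!("{:?}", encoding.get_offsets()); // (start, stop) offsets into the input
        Ok(())
    }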
router/src/lib.rs
@@ -432,6 +432,18 @@ pub struct Token {
     special: bool,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct SimpleToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(example = 0)]
+    start: usize,
+    #[schema(example = 2)]
+    stop: usize,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 pub(crate) enum FinishReason {
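Serialized, each SimpleToken becomes a flat JSON object. A quick sketch of the wire shape (serde and serde_json assumed as illustrative dependencies; the field values are made up):

    use serde::Serialize;

    #[derive(Serialize)]
    struct SimpleToken {
        id: u32,
        text: String,
        start: usize,
        stop: usize,
    }

    fn main() {
        let token = SimpleToken { id: 3666, text: "My".into(), start: 0, stop: 2 };
        // Prints: {"id":3666,"text":"My","start":0,"stop":2}
        println!("{}", serde_json::to_string(&token).unwrap());
    }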
router/src/server.rs
@@ -5,8 +5,8 @@ use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
-    Token, Validation,
+    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -528,7 +528,7 @@ async fn generate_stream_internal(
 /// Generate tokens
 #[utoipa::path(
     post,
-    tag = "Text Generation Inference",
+    tag = "Chat completions",
     path = "/v1/chat/completions",
     request_body = ChatRequest,
     responses(
@@ -672,6 +672,52 @@ async fn chat_completions(
     }
 }
 
+/// Tokenize inputs
+#[utoipa::path(
+    post,
+    tag = "Tokenize",
+    path = "/tokenize",
+    request_body = TokenizeRequest,
+    responses(
+        (status = 200, description = "Tokenized ids", body = TokenizeResponse),
+        (status = 404, description = "No tokenizer found", body = ErrorResponse,
+        example = json!({"error": "No fast tokenizer available"})),
+    )
+)]
+#[instrument(skip_all)]
+async fn tokenize(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let input = req.inputs.clone();
+    let encoding = infer.tokenize(req).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, (start, stop))| {
+                let text: String = input.chars().skip(*start).take(stop - start).collect();
+                SimpleToken {
+                    id,
+                    text,
+                    start: *start,
+                    stop: *stop,
+                }
+            })
+            .collect();
+        Ok(Json(tokens).into_response())
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
+}
+
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
@@ -867,6 +913,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics));
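With the handler and route above in place, the endpoint can be exercised end to end. A hedged client sketch, not part of the commit: it assumes the router listens on localhost:3000 and uses reqwest (json feature), tokio, serde, and serde_json as illustrative dependencies. The request body is an ordinary GenerateRequest, so only inputs is required; the response is a JSON array of SimpleToken, and a 404 signals that no fast tokenizer is available:

    use serde::Deserialize;

    // Mirrors the SimpleToken schema defined in lib.rs above.
    #[derive(Debug, Deserialize)]
    struct SimpleToken {
        id: u32,
        text: String,
        start: usize,
        stop: usize,
    }

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let tokens: Vec<SimpleToken> = reqwest::Client::new()
            .post("http://localhost:3000/tokenize")
            .json(&serde_json::json!({ "inputs": "My name is Olivier" }))
            .send()
            .await?
            .error_for_status()? // 404 means "No fast tokenizer available"
            .json()
            .await?;
        for token in &tokens {
            println!("{:>6} {:?} [{}..{}]", token.id, token.text, token.start, token.stop);
        }
        Ok(())
    }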
router/src/validation.rs
@@ -70,12 +70,11 @@ impl Validation {
     }
 
     #[instrument(skip(self, inputs))]
-    async fn validate_input(
+    pub async fn tokenize(
         &self,
         inputs: String,
         truncate: Option<usize>,
-        max_new_tokens: Option<u32>,
-    ) -> Result<(String, usize, u32), ValidationError> {
+    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -88,7 +87,24 @@ impl Validation {
 
             // Await on response channel
             // Unwrap is safe here
-            let (inputs, input_length) = response_receiver.await.unwrap()?;
+            let encoding = response_receiver.await.unwrap()?;
+            Ok(Some(encoding))
+        } else {
+            Ok(None)
+        }
+    }
+
+    #[instrument(skip(self, inputs))]
+    async fn validate_input(
+        &self,
+        inputs: String,
+        truncate: Option<usize>,
+        max_new_tokens: Option<u32>,
+    ) -> Result<(String, usize, u32), ValidationError> {
+        // If we have a fast tokenizer
+        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+            // Create response channel
+            let input_length = encoding.len();
 
             // Get total tokens
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@@ -346,33 +362,27 @@ fn prepare_input(
     inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
         .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
-                .decode(encoding.get_ids(), false)
-                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
         }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
+    }
 
-    Ok((inputs, input_length))
-}
+    let inputs = tokenizer
+        .decode(encoding.get_ids(), false)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+    Ok((encoding, inputs))
+}
 
 type TokenizerRequest = (
     (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
 );
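One behavioral note on prepare_input: TruncationDirection::Left drops tokens from the front of the sequence, so a truncated prompt keeps its most recent tokens. A small sketch of that call in isolation, assuming the tokenizers crate with from_pretrained support; the model identifier is an arbitrary example:

    use tokenizers::{Tokenizer, TruncationDirection};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;
        let mut encoding = tokenizer.encode("one two three four five", true)?;
        // Same call shape as prepare_input: keep at most 3 tokens,
        // discarding from the left (the start of the prompt).
        encoding.truncate(3, 0, TruncationDirection::Left);
        assert!(encoding.len() <= 3);
        println!("{}", tokenizer.decode(encoding.get_ids(), false)?);
        Ok(())
    }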