Adding tokenizer route.

Nicolas Patry 2024-01-23 14:49:04 +01:00
parent 98e5faff9d
commit 4f7f617e91
4 changed files with 113 additions and 22 deletions

View File

@@ -165,6 +165,28 @@ impl Infer {
        ))
    }
    /// Tokenize the input
    #[instrument(skip_all)]
    pub(crate) async fn tokenize(
        &self,
        request: GenerateRequest,
    ) -> Result<Option<tokenizers::Encoding>, InferError> {
        // Tokenize request
        let inputs = request.inputs;
        let truncate = request.parameters.truncate;
        let encoding = self
            .validation
            .tokenize(inputs, truncate)
            .await
            .map_err(|err| {
                tracing::error!("Tokenization {err}");
                err
            })?;

        // Return Encoding
        Ok(encoding.map(|(encoding, _)| encoding))
    }

    /// Apply the chat template to the chat request
    #[instrument(skip_all)]
    pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
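The Option<tokenizers::Encoding> returned above is what the new /tokenize route consumes: parallel per-token ids and offsets. A minimal standalone sketch of that usage, assuming the tokenizers crate as a dependency and a hypothetical local tokenizer.json (neither is part of this commit):

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local tokenizer file; any fast tokenizer works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let encoding = tokenizer.encode("My name is Olivier", true)?;

    // The /tokenize route zips these two parallel slices into SimpleToken values.
    for (id, (start, stop)) in encoding.get_ids().iter().zip(encoding.get_offsets()) {
        println!("id={id} span={start}..{stop}");
    }
    Ok(())
}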

View File

@@ -432,6 +432,18 @@ pub struct Token {
    special: bool,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
    #[schema(example = 0)]
    id: u32,
    #[schema(example = "test")]
    text: String,
    #[schema(example = 0)]
    start: usize,
    #[schema(example = 2)]
    stop: usize,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
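For reference, the wire format SimpleToken produces. A standalone sketch that mirrors the struct and serializes it with serde_json (the struct copy and all values are illustrative; serde with the derive feature and serde_json are assumed dependencies; the actual /tokenize response is a JSON array of these objects):

use serde::Serialize;

// Standalone mirror of SimpleToken above, for illustration only.
#[derive(Debug, Serialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

fn main() {
    let token = SimpleToken {
        id: 42, // made-up id
        text: "Hello".to_string(),
        start: 0,
        stop: 5,
    };
    // Prints: {"id":42,"text":"Hello","start":0,"stop":5}
    println!("{}", serde_json::to_string(&token).unwrap());
}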

View File

@@ -5,8 +5,8 @@ use crate::validation::ValidationError;
use crate::{
    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
-    Token, Validation,
+    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -528,7 +528,7 @@ async fn generate_stream_internal(
/// Generate tokens
#[utoipa::path(
    post,
-    tag = "Text Generation Inference",
+    tag = "Chat completions",
    path = "/v1/chat/completions",
    request_body = ChatRequest,
    responses(
@@ -672,6 +672,52 @@ async fn chat_completions(
    }
}

/// Tokenize inputs
#[utoipa::path(
    post,
    tag = "Tokenize",
    path = "/tokenize",
    request_body = TokenizeRequest,
    responses(
        (status = 200, description = "Tokenized ids", body = TokenizeResponse),
        (status = 404, description = "No tokenizer found", body = ErrorResponse,
            example = json ! ({"error": "No fast tokenizer available"})),
    )
)]
#[instrument(skip_all)]
async fn tokenize(
    Extension(infer): Extension<Infer>,
    Json(req): Json<GenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let input = req.inputs.clone();
    let encoding = infer.tokenize(req).await?;
    if let Some(encoding) = encoding {
        let tokens: Vec<SimpleToken> = encoding
            .get_ids()
            .iter()
            .zip(encoding.get_offsets())
            .map(|(&id, (start, stop))| {
                let text: String = input.chars().skip(*start).take(stop - start).collect();
                SimpleToken {
                    id,
                    text,
                    start: *start,
                    stop: *stop,
                }
            })
            .collect();
        Ok(Json(tokens).into_response())
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
                error_type: "no fast tokenizer".to_string(),
            }),
        ))
    }
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
@@ -867,6 +913,7 @@ pub async fn run(
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
        .route("/ping", get(health))
        .route("/metrics", get(metrics));

View File

@@ -70,12 +70,11 @@ impl Validation {
    }

    #[instrument(skip(self, inputs))]
-    async fn validate_input(
+    pub async fn tokenize(
        &self,
        inputs: String,
        truncate: Option<usize>,
-        max_new_tokens: Option<u32>,
-    ) -> Result<(String, usize, u32), ValidationError> {
+    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
@@ -88,7 +87,24 @@ impl Validation {
            // Await on response channel
            // Unwrap is safe here
-            let (inputs, input_length) = response_receiver.await.unwrap()?;
+            let encoding = response_receiver.await.unwrap()?;
            Ok(Some(encoding))
        } else {
            Ok(None)
        }
    }

    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(String, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Create response channel
            let input_length = encoding.len();

            // Get total tokens
            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@@ -346,33 +362,27 @@ fn prepare_input(
    inputs: String,
    truncate: Option<usize>,
    tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
    // Get the number of tokens in the input
    let mut encoding = tokenizer
        .encode(inputs.clone(), true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

    // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
            encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
-                .decode(encoding.get_ids(), false)
-                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
        }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
+    }
+
+    let inputs = tokenizer
+        .decode(encoding.get_ids(), false)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

-    Ok((inputs, input_length))
+    Ok((encoding, inputs))
}
type TokenizerRequest = (
    (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
    Span,
);
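The reworked prepare_input always decodes the (possibly truncated) encoding back into the inputs string and now returns the encoding itself. A standalone sketch of that truncation behavior, assuming the tokenizers crate and a hypothetical local tokenizer.json (both illustrative):

use tokenizers::{Tokenizer, TruncationDirection};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let mut encoding = tokenizer.encode("one two three four five", true)?;

    // Keep only the last 3 tokens, mirroring the left-truncation above.
    let truncate = 3;
    if truncate < encoding.len() {
        encoding.truncate(truncate, 0, TruncationDirection::Left);
    }

    // Decode the truncated ids back into a string, as prepare_input does.
    let inputs = tokenizer.decode(encoding.get_ids(), false)?;
    println!("{} tokens kept: {inputs:?}", encoding.len());
    Ok(())
}

Returning the Encoding alongside the decoded inputs lets validate_input reuse the token count (encoding.len()) while the /tokenize route reuses the same tokenization path.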