feat(router): add openAPI schemas

This commit is contained in:
OlivierDehaene 2023-02-02 17:31:52 +01:00
parent b1482d9048
commit 2878c43cc5
11 changed files with 374 additions and 112 deletions

183
Cargo.lock generated
View File

@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.17"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
dependencies = [
"async-trait",
"axum-core",
@ -101,8 +101,10 @@ dependencies = [
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
@ -114,9 +116,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.2.9"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [
"async-trait",
"bytes",
@ -124,6 +126,7 @@ dependencies = [
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
@ -207,7 +210,7 @@ dependencies = [
"tar",
"tempfile",
"thiserror",
"zip",
"zip 0.5.13",
"zip-extensions",
]
@ -465,6 +468,15 @@ dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.3.7"
@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
dependencies = [
"autocfg",
"hashbrown",
"serde",
]
[[package]]
@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchit"
version = "0.5.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
@ -1024,6 +1037,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1552,12 +1575,62 @@ dependencies = [
"winreg",
]
[[package]]
name = "rust-embed"
version = "6.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "6.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"shellexpand",
"syn",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "7.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustversion"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
[[package]]
name = "ryu"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.20"
@ -1628,6 +1701,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -1660,6 +1742,15 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shellexpand"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
dependencies = [
"dirs 4.0.0",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
@ -1845,6 +1936,8 @@ dependencies = [
"tokio-stream",
"tracing",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
]
[[package]]
@ -1921,7 +2014,7 @@ dependencies = [
"cached-path",
"clap 2.34.0",
"derive_builder",
"dirs",
"dirs 3.0.2",
"esaxx-rs",
"getrandom",
"indicatif 0.15.0",
@ -2234,6 +2327,15 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.8"
@ -2293,6 +2395,46 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utoipa"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
dependencies = [
"indexmap",
"serde",
"serde_json",
"utoipa-gen",
]
[[package]]
name = "utoipa-gen"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "utoipa-swagger-ui"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
dependencies = [
"axum",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"utoipa",
"zip 0.6.4",
]
[[package]]
name = "valuable"
version = "0.1.0"
@ -2317,6 +2459,17 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
dependencies = [
"same-file",
"winapi",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.0"
@ -2589,11 +2742,23 @@ dependencies = [
"time",
]
[[package]]
name = "zip"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]]
name = "zip-extensions"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
dependencies = [
"zip",
"zip 0.5.13",
]

View File

@ -71,13 +71,19 @@ message Batch {
uint32 size = 3;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
string finish_reason = 3;
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}

View File

@ -14,7 +14,7 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.3"
axum = { version = "0.5.16", features = ["json", "serde_json"] }
axum = { version = "0.6.4", features = ["json"] }
text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -7,8 +7,8 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
StoppingCriteriaParameters,
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -127,7 +127,7 @@ impl Infer {
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob))
.map(|((id, logprob), text)| Token { id, text, logprob })
.collect();
}
// Push last token
@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
}
// Create last Token
let token = Token(
generation.token_id,
generation.token_text,
generation.token_logprob,
);
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
};
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message

View File

@ -1,5 +1,4 @@
/// Text Generation Inference Webserver
mod infer;
mod queue;
pub mod server;
@ -8,45 +7,39 @@ mod validation;
use infer::Infer;
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validation::Validation;
#[derive(Clone, Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
#[serde(default = "default_top_k")]
pub top_k: i32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default)]
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
pub temperature: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
pub repetition_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")]
pub top_p: Option<f32>,
#[serde(default = "default_do_sample")]
#[schema(default = "false")]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(max_items = 4, default = "null")]
pub stop: Vec<String>,
#[serde(default)]
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
pub seed: Option<u64>,
}
fn default_temperature() -> f32 {
1.0
}
fn default_repetition_penalty() -> f32 {
1.0
}
fn default_top_k() -> i32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool {
false
}
@ -57,10 +50,10 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: default_temperature(),
repetition_penalty: default_repetition_penalty(),
top_k: default_top_k(),
top_p: default_top_p(),
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
stop: vec![],
@ -69,42 +62,71 @@ fn default_parameters() -> GenerateParameters {
}
}
#[derive(Clone, Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
}
#[derive(Debug, Serialize)]
pub struct Token(u32, String, f32);
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
id: u32,
text: String,
logprob: f32,
}
#[derive(Serialize)]
#[derive(Serialize, ToSchema)]
pub(crate) enum FinishReason {
Length,
EndOfSequenceToken,
StopSequence,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
pub finish_reason: String,
pub finish_reason: FinishReason,
pub generated_tokens: u32,
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill: Option<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens: Option<Vec<Token>>,
}
#[derive(Serialize)]
#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>,
}
#[derive(Serialize)]
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
pub finish_reason: FinishReason,
pub generated_tokens: u32,
pub seed: Option<u64>,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
pub generated_text: Option<String>,
pub details: Option<Details>,
pub details: Option<StreamDetails>,
}
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
pub error: String,
#[derive(Serialize, ToSchema)]
pub(crate) enum ErrorType {
#[schema(example = "Request failed during generation")]
GenerationError(String),
#[schema(example = "Model is overloaded")]
Overloaded(String),
#[schema(example = "Input validation error")]
ValidationError(String),
#[schema(example = "Incomplete generation")]
IncompleteGeneration(String),
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
pub error: ErrorType,
}

View File

@ -1,8 +1,8 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
StreamResponse, Validation,
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
@ -19,6 +19,8 @@ use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Health check method
#[instrument(skip(infer))]
@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
temperature: 1.0,
repetition_penalty: 1.0,
top_k: 0,
top_p: 1.0,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
do_sample: false,
max_new_tokens: 1,
stop: vec![],
stop: Vec::new(),
details: false,
seed: None,
},
@ -48,6 +50,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
}
/// Generate method
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)]
#[instrument(
skip(infer),
fields(
@ -76,7 +79,7 @@ async fn generate(
// Token details
let details = match details {
true => Some(Details {
finish_reason: response.generated_text.finish_reason,
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill),
tokens: Some(response.tokens),
@ -133,6 +136,7 @@ async fn generate(
}
/// Generate stream method
#[utoipa::path(post, path = "/generate_stream")]
#[instrument(
skip(infer),
fields(
@ -185,11 +189,9 @@ async fn generate_stream(
} => {
// Token details
let details = match details {
true => Some(Details {
finish_reason: generated_text.finish_reason,
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
prefill: None,
tokens: None,
seed: generated_text.seed,
}),
false => None,
@ -265,6 +267,33 @@ pub async fn run(
validation_workers: usize,
addr: SocketAddr,
) {
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
generate,
generate_stream,
),
components(
schemas(
GenerateRequest,
GenerateParameters,
Token,
GenerateResponse,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
ErrorType
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
)
)]
struct ApiDoc;
// Create state
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let infer = Infer::new(
@ -277,6 +306,7 @@ pub async fn run(
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
@ -320,6 +350,30 @@ async fn shutdown_signal() {
tracing::info!("signal received, starting graceful shutdown");
}
impl From<i32> for FinishReason {
fn from(finish_reason: i32) -> Self {
let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
match finish_reason {
text_generation_client::FinishReason::Length => FinishReason::Length,
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
}
}
}
impl From<InferError> for ErrorResponse {
fn from(err: InferError) -> Self {
let err_string = err.to_string();
let error = match err {
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
};
ErrorResponse { error }
}
}
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
@ -330,21 +384,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
};
(
status_code,
Json(ErrorResponse {
error: err.to_string(),
}),
)
(status_code, Json(ErrorResponse::from(err)))
}
}
impl From<InferError> for Event {
fn from(err: InferError) -> Self {
Event::default()
.json_data(ErrorResponse {
error: err.to_string(),
})
.json_data(ErrorResponse::from(err))
.unwrap()
}
}

View File

@ -110,30 +110,58 @@ fn validate(
max_input_length: usize,
rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
if request.parameters.temperature <= 0.0 {
let GenerateParameters {
temperature,
repetition_penalty,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
seed,
..
} = request.parameters;
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
return Err(ValidationError::Temperature);
}
if request.parameters.repetition_penalty <= 0.0 {
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
if repetition_penalty <= 0.0 {
return Err(ValidationError::RepetitionPenalty);
}
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
let top_p = top_p.unwrap_or(1.0);
if top_p <= 0.0 || top_p > 1.0 {
return Err(ValidationError::TopP);
}
if request.parameters.top_k < 0 {
return Err(ValidationError::TopK);
}
if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
// Different because the proto default value is 0 while it is not a valid value
// for the user
let top_k: u32 = match top_k {
None => Ok(0),
Some(top_k) => {
if top_k <= 0 {
return Err(ValidationError::TopK);
}
Ok(top_k as u32)
}
}?;
if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
}
if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
if stop_sequences.len() > MAX_STOP_SEQUENCES {
return Err(ValidationError::StopSequence(
MAX_STOP_SEQUENCES,
request.parameters.stop.len(),
stop_sequences.len(),
));
}
// If seed is None, assign a random one
let seed = match request.parameters.seed {
let seed = match seed {
None => rng.gen(),
Some(seed) => seed,
};
@ -147,21 +175,10 @@ fn validate(
Err(ValidationError::InputLength(input_length, max_input_length))
} else {
// Return ValidGenerateRequest
let GenerateParameters {
temperature,
repetition_penalty,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
..
} = request.parameters;
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
top_k: top_k as u32,
top_k,
top_p,
do_sample,
seed,
@ -206,7 +223,7 @@ pub enum ValidationError {
TopP,
#[error("top_k must be strictly positive")]
TopK,
#[error("max_new_tokens must be <= {0}")]
#[error("max_new_tokens must be strictly positive and <= {0}")]
MaxNewTokens(u32),
#[error("inputs must have less than {1} tokens. Given: {0}")]
InputLength(usize, usize),

View File

@ -9,6 +9,7 @@ from text_generation.utils import (
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
FinishReason
)
@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
def test_stopping_criteria():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, "stop_sequence")
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, "eos_token")
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
def test_stopping_criteria_max():
@ -39,7 +40,7 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, "length")
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():

View File

@ -7,6 +7,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
class Batch(ABC):
@ -38,7 +39,7 @@ class Batch(ABC):
class GeneratedText:
text: str
generated_tokens: int
finish_reason: str
finish_reason: FinishReason
seed: Optional[int]
def to_pb(self) -> generate_pb2.GeneratedText:

View File

@ -24,6 +24,7 @@ from transformers.generation.logits_process import (
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
@ -129,15 +130,15 @@ class StoppingCriteria:
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, "length"
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, "eos_token"
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, "stop_sequence"
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None