diff --git a/Cargo.lock b/Cargo.lock index 5b671fa8..d006267e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.17" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" +checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc" dependencies = [ "async-trait", "axum-core", @@ -101,8 +101,10 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", + "rustversion", "serde", "serde_json", + "serde_path_to_error", "serde_urlencoded", "sync_wrapper", "tokio", @@ -114,9 +116,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.9" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" +checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" dependencies = [ "async-trait", "bytes", @@ -124,6 +126,7 @@ dependencies = [ "http", "http-body", "mime", + "rustversion", "tower-layer", "tower-service", ] @@ -207,7 +210,7 @@ dependencies = [ "tar", "tempfile", "thiserror", - "zip", + "zip 0.5.13", "zip-extensions", ] @@ -465,6 +468,15 @@ dependencies = [ "dirs-sys", ] +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs-sys" version = "0.3.7" @@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" dependencies = [ "autocfg", "hashbrown", + "serde", ] [[package]] @@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" [[package]] name = "matchit" -version = "0.5.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" [[package]] name = "memchr" @@ -1024,6 +1037,16 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1552,12 +1575,62 @@ dependencies = [ "winreg", ] +[[package]] +name = "rust-embed" +version = "6.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "6.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "shellexpand", + "syn", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "7.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054" +dependencies = [ + "sha2", + "walkdir", +] + +[[package]] +name = "rustversion" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" + [[package]] name = "ryu" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.20" @@ -1628,6 +1701,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1660,6 +1742,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shellexpand" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" +dependencies = [ + "dirs 4.0.0", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1845,6 +1936,8 @@ dependencies = [ "tokio-stream", "tracing", "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", ] [[package]] @@ -1921,7 +2014,7 @@ dependencies = [ "cached-path", "clap 2.34.0", "derive_builder", - "dirs", + "dirs 3.0.2", "esaxx-rs", "getrandom", "indicatif 0.15.0", @@ -2234,6 +2327,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.8" @@ -2293,6 +2395,46 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utoipa" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c" +dependencies = [ + "indexmap", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-gen" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2" +dependencies = [ + "axum", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "utoipa", + "zip 0.6.4", +] + [[package]] name = "valuable" version = "0.1.0" @@ -2317,6 +2459,17 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +dependencies = [ + "same-file", + "winapi", + "winapi-util", +] + [[package]] name = "want" version = "0.3.0" @@ -2589,11 +2742,23 @@ dependencies = [ "time", ] +[[package]] +name = "zip" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", + "flate2", +] + [[package]] name = "zip-extensions" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14" dependencies = [ - "zip", + "zip 0.5.13", ] diff --git a/proto/generate.proto b/proto/generate.proto index 098df9c5..0c4f9626 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -71,13 +71,19 @@ message Batch { uint32 size = 3; } +enum FinishReason { + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; +} + message GeneratedText { /// Output string text = 1; /// Number of generated tokens uint32 generated_tokens = 2; /// Finish reason - string finish_reason = 3; + FinishReason finish_reason = 3; /// Seed optional uint64 seed = 4; } diff --git a/router/Cargo.toml b/router/Cargo.toml index 3abbc80b..56d6b4f5 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -14,7 +14,7 @@ path = "src/main.rs" [dependencies] async-stream = "0.3.3" -axum = { version = "0.5.16", features = ["json", "serde_json"] } +axum = { version = "0.6.4", features = ["json"] } text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" @@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot" tokio-stream = "0.1.11" tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["json"] } +utoipa = { version = "3.0.1", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index e0546b16..2767c605 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,8 +7,8 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request, - StoppingCriteriaParameters, + Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, + Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/infer.rs b/router/src/infer.rs index 3661b0e0..159b7ca7 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -127,7 +127,7 @@ impl Infer { .into_iter() .zip(tokens.logprobs.into_iter()) .zip(tokens.texts.into_iter()) - .map(|((id, logprob), text)| Token(id, text, logprob)) + .map(|((id, logprob), text)| Token { id, text, logprob }) .collect(); } // Push last token @@ -282,11 +282,11 @@ fn send_generations(generations: Vec, entries: &mut IntMap, + #[serde(default)] + #[schema(exclusive_minimum = 0.0, nullable = true, default = "null")] + pub repetition_penalty: Option, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null")] + pub top_k: Option, + #[serde(default)] + #[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")] + pub top_p: Option, #[serde(default = "default_do_sample")] + #[schema(default = "false")] pub do_sample: bool, #[serde(default = "default_max_new_tokens")] + #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] + #[schema(max_items = 4, default = "null")] pub stop: Vec, #[serde(default)] + #[schema(default = "true")] pub details: bool, #[serde(default)] pub seed: Option, } -fn default_temperature() -> f32 { - 1.0 -} -fn default_repetition_penalty() -> f32 { - 1.0 -} - -fn default_top_k() -> i32 { - 0 -} - -fn default_top_p() -> f32 { - 1.0 -} - fn default_do_sample() -> bool { false } @@ -57,10 +50,10 @@ fn default_max_new_tokens() -> u32 { fn default_parameters() -> GenerateParameters { GenerateParameters { - temperature: default_temperature(), - repetition_penalty: default_repetition_penalty(), - top_k: default_top_k(), - top_p: default_top_p(), + temperature: None, + repetition_penalty: None, + top_k: None, + top_p: None, do_sample: default_do_sample(), max_new_tokens: default_max_new_tokens(), stop: vec![], @@ -69,42 +62,71 @@ fn default_parameters() -> GenerateParameters { } } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateRequest { + #[schema(example = "My name is Olivier and I")] pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, } -#[derive(Debug, Serialize)] -pub struct Token(u32, String, f32); +#[derive(Debug, Serialize, ToSchema)] +pub struct Token { + id: u32, + text: String, + logprob: f32, +} -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] +pub(crate) enum FinishReason { + Length, + EndOfSequenceToken, + StopSequence, +} + +#[derive(Serialize, ToSchema)] pub(crate) struct Details { - pub finish_reason: String, + pub finish_reason: FinishReason, pub generated_tokens: u32, pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub prefill: Option>, - #[serde(skip_serializing_if = "Option::is_none")] pub tokens: Option>, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub(crate) struct GenerateResponse { pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] +pub(crate) struct StreamDetails { + pub finish_reason: FinishReason, + pub generated_tokens: u32, + pub seed: Option, +} + +#[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { pub token: Token, pub generated_text: Option, - pub details: Option
, + pub details: Option, } -#[derive(Serialize)] -pub(crate) struct ErrorResponse { - pub error: String, +#[derive(Serialize, ToSchema)] +pub(crate) enum ErrorType { + #[schema(example = "Request failed during generation")] + GenerationError(String), + #[schema(example = "Model is overloaded")] + Overloaded(String), + #[schema(example = "Input validation error")] + ValidationError(String), + #[schema(example = "Incomplete generation")] + IncompleteGeneration(String), +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ErrorResponse { + pub error: ErrorType, } diff --git a/router/src/server.rs b/router/src/server.rs index f79644e3..fd77a802 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,8 +1,8 @@ /// HTTP Server logic use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, - StreamResponse, Validation, + Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest, + GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; @@ -19,6 +19,8 @@ use tokio::signal; use tokio::time::Instant; use tokio_stream::StreamExt; use tracing::instrument; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; /// Health check method #[instrument(skip(infer))] @@ -32,13 +34,13 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json) -> Result<(), (StatusCode, Json Some(Details { - finish_reason: response.generated_text.finish_reason, + finish_reason: FinishReason::from(response.generated_text.finish_reason), generated_tokens: response.generated_text.generated_tokens, prefill: Some(response.prefill), tokens: Some(response.tokens), @@ -133,6 +136,7 @@ async fn generate( } /// Generate stream method +#[utoipa::path(post, path = "/generate_stream")] #[instrument( skip(infer), fields( @@ -185,11 +189,9 @@ async fn generate_stream( } => { // Token details let details = match details { - true => Some(Details { - finish_reason: generated_text.finish_reason, + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), generated_tokens: generated_text.generated_tokens, - prefill: None, - tokens: None, seed: generated_text.seed, }), false => None, @@ -265,6 +267,33 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, ) { + // OpenAPI documentation + #[derive(OpenApi)] + #[openapi( + paths( + generate, + generate_stream, + ), + components( + schemas( + GenerateRequest, + GenerateParameters, + Token, + GenerateResponse, + Details, + FinishReason, + StreamResponse, + StreamDetails, + ErrorResponse, + ErrorType + ) + ), + tags( + (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") + ) + )] + struct ApiDoc; + // Create state let validation = Validation::new(validation_workers, tokenizer, max_input_length); let infer = Infer::new( @@ -277,6 +306,7 @@ pub async fn run( // Create router let app = Router::new() + .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .route("/", post(generate)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) @@ -320,6 +350,30 @@ async fn shutdown_signal() { tracing::info!("signal received, starting graceful shutdown"); } +impl From for FinishReason { + fn from(finish_reason: i32) -> Self { + let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap(); + match finish_reason { + text_generation_client::FinishReason::Length => FinishReason::Length, + text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence, + } + } +} + +impl From for ErrorResponse { + fn from(err: InferError) -> Self { + let err_string = err.to_string(); + let error = match err { + InferError::GenerationError(_) => ErrorType::GenerationError(err_string), + InferError::Overloaded(_) => ErrorType::Overloaded(err_string), + InferError::ValidationError(_) => ErrorType::ValidationError(err_string), + InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string), + }; + ErrorResponse { error } + } +} + /// Convert to Axum supported formats impl From for (StatusCode, Json) { fn from(err: InferError) -> Self { @@ -330,21 +384,14 @@ impl From for (StatusCode, Json) { InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, }; - ( - status_code, - Json(ErrorResponse { - error: err.to_string(), - }), - ) + (status_code, Json(ErrorResponse::from(err))) } } impl From for Event { fn from(err: InferError) -> Self { Event::default() - .json_data(ErrorResponse { - error: err.to_string(), - }) + .json_data(ErrorResponse::from(err)) .unwrap() } } diff --git a/router/src/validation.rs b/router/src/validation.rs index 09220823..b7b33a19 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -110,30 +110,58 @@ fn validate( max_input_length: usize, rng: &mut ThreadRng, ) -> Result { - if request.parameters.temperature <= 0.0 { + let GenerateParameters { + temperature, + repetition_penalty, + top_k, + top_p, + do_sample, + max_new_tokens, + stop: stop_sequences, + seed, + .. + } = request.parameters; + + let temperature = temperature.unwrap_or(1.0); + if temperature <= 0.0 { return Err(ValidationError::Temperature); } - if request.parameters.repetition_penalty <= 0.0 { + + let repetition_penalty = repetition_penalty.unwrap_or(1.0); + if repetition_penalty <= 0.0 { return Err(ValidationError::RepetitionPenalty); } - if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { + + let top_p = top_p.unwrap_or(1.0); + if top_p <= 0.0 || top_p > 1.0 { return Err(ValidationError::TopP); } - if request.parameters.top_k < 0 { - return Err(ValidationError::TopK); - } - if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS { + + // Different because the proto default value is 0 while it is not a valid value + // for the user + let top_k: u32 = match top_k { + None => Ok(0), + Some(top_k) => { + if top_k <= 0 { + return Err(ValidationError::TopK); + } + Ok(top_k as u32) + } + }?; + + if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS { return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); } - if request.parameters.stop.len() > MAX_STOP_SEQUENCES { + + if stop_sequences.len() > MAX_STOP_SEQUENCES { return Err(ValidationError::StopSequence( MAX_STOP_SEQUENCES, - request.parameters.stop.len(), + stop_sequences.len(), )); } // If seed is None, assign a random one - let seed = match request.parameters.seed { + let seed = match seed { None => rng.gen(), Some(seed) => seed, }; @@ -147,21 +175,10 @@ fn validate( Err(ValidationError::InputLength(input_length, max_input_length)) } else { // Return ValidGenerateRequest - let GenerateParameters { - temperature, - repetition_penalty, - top_k, - top_p, - do_sample, - max_new_tokens, - stop: stop_sequences, - .. - } = request.parameters; - let parameters = NextTokenChooserParameters { temperature, repetition_penalty, - top_k: top_k as u32, + top_k, top_p, do_sample, seed, @@ -206,7 +223,7 @@ pub enum ValidationError { TopP, #[error("top_k must be strictly positive")] TopK, - #[error("max_new_tokens must be <= {0}")] + #[error("max_new_tokens must be strictly positive and <= {0}")] MaxNewTokens(u32), #[error("inputs must have less than {1} tokens. Given: {0}")] InputLength(usize, usize), diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py index 1dc6801b..e43fe501 100644 --- a/server/tests/test_utils.py +++ b/server/tests/test_utils.py @@ -9,6 +9,7 @@ from text_generation.utils import ( StopSequenceCriteria, StoppingCriteria, LocalEntryNotFoundError, + FinishReason ) @@ -24,13 +25,13 @@ def test_stop_sequence_criteria(): def test_stopping_criteria(): criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) assert criteria(65827, "/test") == (False, None) - assert criteria(30, ";") == (True, "stop_sequence") + assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE) def test_stopping_criteria_eos(): criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) assert criteria(1, "") == (False, None) - assert criteria(0, "") == (True, "eos_token") + assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN) def test_stopping_criteria_max(): @@ -39,7 +40,7 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) - assert criteria(1, "") == (True, "length") + assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) def test_weight_hub_files(): diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 30cd716a..d1117b80 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -7,6 +7,7 @@ from typing import List, Optional from transformers import PreTrainedTokenizerBase from text_generation.pb import generate_pb2 +from text_generation.pb.generate_pb2 import FinishReason class Batch(ABC): @@ -38,7 +39,7 @@ class Batch(ABC): class GeneratedText: text: str generated_tokens: int - finish_reason: str + finish_reason: FinishReason seed: Optional[int] def to_pb(self) -> generate_pb2.GeneratedText: diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 8bbc7377..c5534ca2 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -24,6 +24,7 @@ from transformers.generation.logits_process import ( ) from text_generation.pb import generate_pb2 +from text_generation.pb.generate_pb2 import FinishReason WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) @@ -129,15 +130,15 @@ class StoppingCriteria: def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: - return True, "length" + return True, FinishReason.FINISH_REASON_LENGTH if last_token == self.eos_token_id: - return True, "eos_token" + return True, FinishReason.FINISH_REASON_EOS_TOKEN self.current_output += last_output for stop_sequence_criteria in self.stop_sequence_criterias: if stop_sequence_criteria(self.current_output): - return True, "stop_sequence" + return True, FinishReason.FINISH_REASON_STOP_SEQUENCE return False, None