diff --git a/router/src/lib.rs b/router/src/lib.rs
index a545f65e..8e3199dd 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -13,10 +13,20 @@ use validation::Validation;
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 0.5)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 0.5
+    )]
     pub temperature: Option<f32>,
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 1.03)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 1.03
+    )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 8094df1f..1a332088 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,8 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -173,7 +173,8 @@ async fn generate(
         example = json!({"error": "Incomplete generation"}),
         content_type="text/event-stream "),
     )
-)]#[instrument(
+)]
+#[instrument(
     skip(infer),
     fields(
         total_time,
@@ -413,14 +414,21 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
         };
 
-        (status_code, Json(ErrorResponse{error: err.to_string()}))
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
     }
 }
 
 impl From<InferError> for Event {
     fn from(err: InferError) -> Self {
         Event::default()
-            .json_data(ErrorResponse{error: err.to_string()})
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
             .unwrap()
     }
 }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index b7b33a19..3cca48e9 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -149,7 +149,7 @@ fn validate(
         }
     }?;
 
-    if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
         return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
     }
 
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 871c0da0..b06d57f5 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
 
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
 
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -283,8 +281,7 @@ def test_batch_concatenate(
 
     assert len(generations) == 2
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -306,8 +303,7 @@ def test_batch_concatenate(
 
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py
index e43fe501..ffe9be65 100644
--- a/server/tests/test_utils.py
+++ b/server/tests/test_utils.py
@@ -9,7 +9,7 @@ from text_generation.utils import (
     StopSequenceCriteria,
     StoppingCriteria,
     LocalEntryNotFoundError,
-    FinishReason
+    FinishReason,
 )
 
 
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index c5534ca2..ea97ed4a 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -28,6 +28,7 @@ from text_generation.pb.generate_pb2 import FinishReason
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 
+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)
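
Note on the one semantic change in this otherwise formatting-only diff: in
router/src/validation.rs, max_new_tokens is an unsigned integer, so the old
"max_new_tokens <= 0" guard could only ever match zero; "== 0" states the same
check without the impossible negative branch. A minimal standalone sketch,
assuming (as in the router) that the field is a u32; the constant's value and
the String error type below are placeholders, not the repository's definitions:

    // Illustrative only: the value 512 and the String error type are
    // assumptions for this sketch, not the repository's actual definitions.
    const MAX_MAX_NEW_TOKENS: u32 = 512;

    fn check_max_new_tokens(max_new_tokens: u32) -> Result<u32, String> {
        // A u32 can never be negative, so `== 0` expresses exactly what the
        // old `<= 0` did, minus the dead comparison against negatives.
        if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
            return Err(format!("max_new_tokens must be in 1..={MAX_MAX_NEW_TOKENS}"));
        }
        Ok(max_new_tokens)
    }

    fn main() {
        assert!(check_max_new_tokens(0).is_err());
        assert!(check_max_new_tokens(10).is_ok());
        assert!(check_max_new_tokens(1024).is_err());
    }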