formatting

OlivierDehaene 2023-02-02 18:59:21 +01:00
parent 4d00990ccd
commit 5de40eb078
6 changed files with 32 additions and 17 deletions

router/src/lib.rs

@@ -13,10 +13,20 @@ use validation::Validation;
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 0.5)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 0.5
+    )]
     pub temperature: Option<f32>,
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 1.03)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 1.03
+    )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
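Note: the hunk above only rewraps utoipa's `#[schema]` attributes over multiple
lines; the generated OpenAPI schema is unchanged. A minimal self-contained
sketch of the same pattern (the struct and its name are illustrative, not part
of this commit):

use serde::Deserialize;
use utoipa::ToSchema;

// Hypothetical request type mirroring the attribute style adopted above.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub struct SamplingOptions {
    // `serde(default)` lets the field be omitted from the request JSON;
    // `schema(...)` documents the OpenAPI property: a nullable number
    // strictly greater than 0.0, defaulting to null, with 0.5 as example.
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 0.5
    )]
    pub temperature: Option<f32>,
}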

router/src/server.rs

@@ -1,8 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -173,7 +173,8 @@ async fn generate(
         example = json!({"error": "Incomplete generation"}),
         content_type="text/event-stream "),
     )
-)]#[instrument(
+)]
+#[instrument(
     skip(infer),
     fields(
         total_time,
@@ -413,14 +414,21 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
         };
-        (status_code, Json(ErrorResponse{error: err.to_string()}))
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
     }
 }

 impl From<InferError> for Event {
     fn from(err: InferError) -> Self {
         Event::default()
-            .json_data(ErrorResponse{error: err.to_string()})
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
             .unwrap()
     }
 }
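Note: the two `From<InferError>` impls above are reformatted, not changed; one
maps an inference error to axum's `(StatusCode, Json<ErrorResponse>)` error
tuple, the other to an SSE `Event` for the streaming route. A self-contained
sketch of the tuple conversion (the stand-in types below are simplified, not
the router's real definitions):

use axum::{http::StatusCode, Json};
use serde::Serialize;

// Simplified stand-ins for the router's types (sketch only).
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

struct InferError(String);

impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        // Same shape the formatted code above produces.
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse { error: err.0 }),
        )
    }
}

A handler whose error type is `(StatusCode, Json<ErrorResponse>)` can then
propagate an `InferError` with plain `?`, which routes through this `From` impl.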

router/src/validation.rs

@@ -149,7 +149,7 @@ fn validate(
         }
     }?;

-    if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
         return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
     }
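Note: this is the one hunk that is not purely cosmetic rewrapping. Assuming
`max_new_tokens` is an unsigned integer (as its use as a token count suggests),
it can never be negative, so `<= 0` could only ever match `== 0`; the rewrite
states that directly. A minimal sketch of the bounds check (the constant's
value is an assumed placeholder, not taken from the repo):

// 512 is illustrative; the real MAX_MAX_NEW_TOKENS is defined elsewhere
// in router/src/validation.rs.
const MAX_MAX_NEW_TOKENS: u32 = 512;

fn check_max_new_tokens(max_new_tokens: u32) -> Result<(), String> {
    // For a u32, `<= 0` can only mean `== 0`; the equality form says so.
    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
        return Err(format!("max_new_tokens must be in 1..={MAX_MAX_NEW_TOKENS}"));
    }
    Ok(())
}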

server/tests/models/test_bloom.py

@@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -283,8 +281,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -306,8 +303,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

server/tests/test_utils.py

@@ -9,7 +9,7 @@ from text_generation.utils import (
     StopSequenceCriteria,
     StoppingCriteria,
     LocalEntryNotFoundError,
-    FinishReason
+    FinishReason,
 )

server/text_generation/utils.py

@@ -28,6 +28,7 @@ from text_generation.pb.generate_pb2 import FinishReason
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)

+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)