formatting

This commit is contained in:
OlivierDehaene 2023-02-02 18:59:21 +01:00
parent 4d00990ccd
commit 5de40eb078
6 changed files with 32 additions and 17 deletions

View File

@ -13,10 +13,20 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 0.5)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 0.5
)]
pub temperature: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 1.03)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 1.03
)]
pub repetition_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]

View File

@ -1,8 +1,8 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
@ -173,7 +173,8 @@ async fn generate(
example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "),
)
)]#[instrument(
)]
#[instrument(
skip(infer),
fields(
total_time,
@ -413,14 +414,21 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
};
(status_code, Json(ErrorResponse{error: err.to_string()}))
(
status_code,
Json(ErrorResponse {
error: err.to_string(),
}),
)
}
}
impl From<InferError> for Event {
fn from(err: InferError) -> Self {
Event::default()
.json_data(ErrorResponse{error: err.to_string()})
.json_data(ErrorResponse {
error: err.to_string(),
})
.unwrap()
}
}

View File

@ -149,7 +149,7 @@ fn validate(
}
}?;
if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
}

View File

@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
assert len(generations) == 1
assert (
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
assert len(generations) == 1
assert (
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@ -283,8 +281,7 @@ def test_batch_concatenate(
assert len(generations) == 2
assert (
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
@ -306,8 +303,7 @@ def test_batch_concatenate(
assert len(generations) == 1
assert (
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

View File

@ -9,7 +9,7 @@ from text_generation.utils import (
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
FinishReason
FinishReason,
)

View File

@ -28,6 +28,7 @@ from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)