mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00

formatting

This commit is contained in:
parent 4d00990ccd
commit 5de40eb078
@@ -13,10 +13,20 @@ use validation::Validation;
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 0.5)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 0.5
+    )]
     pub temperature: Option<f32>,
     #[serde(default)]
-    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 1.03)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 1.03
+    )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
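The hunk above is purely cosmetic: rustfmt re-wraps the long `#[schema(...)]` attributes over several lines, and the serde/utoipa behaviour is unchanged. As a rough, self-contained sketch of how these optional sampling fields deserialize (only the two field names are taken from the hunk; the reduced struct and the example values are illustrative, not the router's real `GenerateParameters`):

// Illustrative sketch, not the actual router struct.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct SamplingParams {
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    repetition_penalty: Option<f32>,
}

fn main() {
    // Missing keys deserialize to `None`, matching the "nullable" schema hint.
    let empty: SamplingParams = serde_json::from_str("{}").unwrap();
    assert!(empty.temperature.is_none());

    // Provided keys come through as `Some(..)`.
    let set: SamplingParams = serde_json::from_str(r#"{"temperature": 0.5}"#).unwrap();
    assert_eq!(set.temperature, Some(0.5));
    assert!(set.repetition_penalty.is_none());
    println!("{set:?}");
}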
@@ -1,8 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -173,7 +173,8 @@ async fn generate(
     example = json!({"error": "Incomplete generation"}),
     content_type="text/event-stream "),
     )
-)]#[instrument(
+)]
+#[instrument(
     skip(infer),
     fields(
         total_time,
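The change here only splits `)]#[instrument(` onto two lines. For readers unfamiliar with the attribute, a rough standalone sketch of the `tracing::instrument` pattern (function and field names below are illustrative, not the server's): `skip(...)` keeps an argument out of the span, and `fields(...)` declares empty fields that the body can fill in later.

// Illustrative sketch of the tracing::instrument pattern, not the server code.
use tracing::{instrument, Span};

#[instrument(skip(payload), fields(total_time))]
fn handle(request_id: u64, payload: Vec<u8>) {
    let start = std::time::Instant::now();
    let _ = payload.len(); // ... the actual work would go here ...
    // Fill in the `total_time` field declared (but left empty) in the attribute.
    Span::current().record("total_time", start.elapsed().as_millis() as u64);
}

fn main() {
    // A subscriber (e.g. tracing_subscriber::fmt) would normally be installed
    // here so the span and its recorded fields are actually emitted.
    handle(1, vec![0u8; 16]);
}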
@@ -413,14 +414,21 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
         };

-        (status_code, Json(ErrorResponse{error: err.to_string()}))
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
     }
 }

 impl From<InferError> for Event {
     fn from(err: InferError) -> Self {
         Event::default()
-            .json_data(ErrorResponse{error: err.to_string()})
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
             .unwrap()
     }
 }
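Again a rustfmt-only change: the `ErrorResponse` struct literals are expanded across lines. For context, a trimmed-down sketch of the same conversion pattern, assuming a simplified two-variant error enum (the real `InferError` has more variants and its own `Display` implementation); in axum, the resulting `(StatusCode, Json<T>)` tuple implements `IntoResponse`, so a handler can return the converted error directly.

// Illustrative sketch with a reduced error enum and made-up messages.
use axum::http::StatusCode;
use axum::Json;
use serde::Serialize;

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

#[derive(Debug)]
enum InferError {
    Overloaded,
    IncompleteGeneration,
}

impl std::fmt::Display for InferError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            InferError::Overloaded => write!(f, "Model is overloaded"),
            InferError::IncompleteGeneration => write!(f, "Incomplete generation"),
        }
    }
}

impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };
        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}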
@@ -149,7 +149,7 @@ fn validate(
         }
     }?;

-    if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
         return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
     }

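This is the only line in the commit that looks behavioural, and even it changes nothing at runtime: assuming `max_new_tokens` is an unsigned integer (presumably a `u32` here), it can never be negative, so `<= 0` and `== 0` accept and reject exactly the same values; the new form just states the intent directly. A minimal sketch of the check, with an illustrative limit constant:

// Minimal sketch of the validation rule, assuming `max_new_tokens: u32`.
const MAX_MAX_NEW_TOKENS: u32 = 512; // illustrative value, not the real constant

#[derive(Debug, PartialEq)]
enum ValidationError {
    MaxNewTokens(u32),
}

fn check_max_new_tokens(max_new_tokens: u32) -> Result<u32, ValidationError> {
    // An unsigned value can never be negative, so `== 0` expresses exactly
    // the same rule that `<= 0` did.
    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
    }
    Ok(max_new_tokens)
}

fn main() {
    assert_eq!(
        check_max_new_tokens(0),
        Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS))
    );
    assert_eq!(check_max_new_tokens(20), Ok(20));
}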
@@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -283,8 +281,7 @@ def test_batch_concatenate(

     assert len(generations) == 2
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -306,8 +303,7 @@ def test_batch_concatenate(

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -9,7 +9,7 @@ from text_generation.utils import (
     StopSequenceCriteria,
     StoppingCriteria,
     LocalEntryNotFoundError,
-    FinishReason
+    FinishReason,
 )

@@ -28,6 +28,7 @@ from text_generation.pb.generate_pb2 import FinishReason

 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)

+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)