fix: improve completion request params and comments

drbh 2024-02-21 11:31:11 -05:00
parent 19c0248985
commit 544f848bde
2 changed files with 33 additions and 7 deletions


@@ -268,25 +268,50 @@ fn default_parameters() -> GenerateParameters {
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
pub prompt: String,
/// The maximum number of tokens that can be generated in the completion.
#[serde(default)]
#[schema(default = "32")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
/// Please see the completion_template field in the model's tokenizer_config.json file for the completion template.
#[serde(default)]
pub suffix: Option<String>,
#[serde(default)]
pub repetition_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
}
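
The struct above is mostly serde and utoipa plumbing, but the combination of `#[serde(default)]`, `Option`, and `#[serde(default = "bool::default")]` decides which fields a client may omit. Below is a minimal, self-contained sketch (not the router's actual code) of how a request with only `model` and `prompt` deserializes; it assumes `serde` with the `derive` feature and `serde_json` as dependencies.

// Minimal stand-in for the fields touched by this commit; the real struct
// lives in the router crate and carries the full `#[schema(...)]` annotations.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct CompletionRequest {
    model: String,
    prompt: String,
    #[serde(default)]
    max_tokens: Option<u32>,
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    top_p: Option<f32>,
    // `bool::default` yields `false`, so omitting `stream` means a non-streaming request.
    #[serde(default = "bool::default")]
    stream: bool,
    seed: Option<u64>,
    #[serde(default)]
    suffix: Option<String>,
    #[serde(default)]
    repetition_penalty: Option<f32>,
    #[serde(default)]
    frequency_penalty: Option<f32>,
}

fn main() {
    // Only `model` and `prompt` are required; everything else falls back to its default.
    let body = r#"{
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "What is Deep Learning?"
    }"#;
    let req: CompletionRequest = serde_json::from_str(body).expect("valid request body");
    assert!(!req.stream);
    assert_eq!(req.max_tokens, None);
    println!("{req:?}");
}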


@@ -574,8 +574,9 @@ async fn completions(
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let seed = req.seed;
let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
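
The handler's defaulting is plain `Option::or`: a request without `max_tokens` ends up asking for 100 new tokens. A small illustration of that behavior, using a hypothetical helper name (`resolve_max_new_tokens` does not exist in the codebase):

// Hypothetical helper mirroring the defaulting in the handler above:
// a missing `max_tokens` falls back to 100 new tokens.
fn resolve_max_new_tokens(max_tokens: Option<u32>) -> Option<u32> {
    max_tokens.or(Some(100))
}

fn main() {
    assert_eq!(resolve_max_new_tokens(None), Some(100));
    assert_eq!(resolve_max_new_tokens(Some(32)), Some(32));
}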