mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 06:12:07 +00:00)
parent 941cd42e0c
commit 1a2d68250a

This commit threads a new `typical_p` (typical sampling) parameter through the stack: the protobuf message definition, the Rust router's request schema and validation, and the Python server's `NextTokenChooser`.
@@ -34,14 +34,16 @@ message NextTokenChooserParameters {
     uint32 top_k = 2;
     /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
     float top_p = 3;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float typical_p = 4;
     /// apply sampling on the logits
-    bool do_sample = 4;
+    bool do_sample = 5;
     /// random seed for sampling
-    uint64 seed = 5;
+    uint64 seed = 6;
     /// repetition penalty
-    float repetition_penalty = 6;
+    float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
-    bool watermark = 7;
+    bool watermark = 8;
 }

 message StoppingCriteriaParameters {
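Because `do_sample`, `seed`, `repetition_penalty` and `watermark` all move to new field numbers here, the router and server stubs have to be regenerated together; an old client still writing field 4 would now be decoded as `typical_p` instead of `do_sample`. Below is a minimal sketch of building the extended message from regenerated Python stubs; the `generate_pb2` import path is an assumption and may differ between versions.

# Sketch only: constructing the extended NextTokenChooserParameters message.
# The stub module path is assumed, not taken from this diff.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    typical_p=0.95,          # new field 4
    do_sample=True,          # renumbered 4 -> 5
    seed=42,                 # renumbered 5 -> 6
    repetition_penalty=1.0,  # renumbered 6 -> 7
    watermark=False,         # renumbered 7 -> 8
)
print(params)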
@@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters {
     )]
     pub top_p: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub typical_p: Option<f32>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
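On the HTTP side this surfaces as an optional `typical_p` member of the request's `parameters` object, documented in the OpenAPI schema as nullable with an exclusive lower bound of 0.0 and a maximum of 1.0. A hedged usage sketch against a locally running router follows; the host, port and prompt are placeholders.

# Sketch only: exercising the new parameter over the router's /generate route.
# Assumes an instance listening on localhost:8080; adjust as needed.
import requests

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Typical decoding keeps tokens whose surprisal is",
        "parameters": {
            "typical_p": 0.95,      # must satisfy 0.0 < typical_p < 1.0 (see the validation hunk below)
            "max_new_tokens": 32,
        },
    },
    timeout=60,
)
print(response.json())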
@@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
@@ -231,6 +231,7 @@ mod tests {
                 temperature: 0.0,
                 top_k: 0,
                 top_p: 0.0,
+                typical_p: 0.0,
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
@@ -68,6 +68,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
             repetition_penalty: None,
             top_k: None,
             top_p: None,
+            typical_p: None,
             do_sample: false,
             max_new_tokens: 1,
             return_full_text: None,
@@ -153,6 +153,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
@@ -171,22 +172,34 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }

-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
-
-    // Different because the proto default value is 0 while it is not a valid value
+    // Different because the proto default value is not a valid value
     // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(top_k as u32)
-        }
-    }?;
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;

     if max_new_tokens == 0 {
         return Err(ValidationError::MaxNewTokens);
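The three sampling knobs now share the same shape: an unset `Option` falls back to the protobuf-friendly default (1.0 for `top_p` and `typical_p`, 0 for `top_k`), while explicitly supplied values must sit strictly inside the valid range. The snippet below is an illustrative Python mirror of the `typical_p` rule only, not the router code itself.

# Illustrative mirror of the new typical_p check: None means "not provided"
# and maps to 1.0 (typical sampling disabled); explicit values must be in (0, 1).
def check_typical_p(typical_p):
    if typical_p is None:
        return 1.0
    if typical_p <= 0.0 or typical_p >= 1.0:
        raise ValueError("`typical_p` must be > 0.0 and < 1.0")
    return typical_p

assert check_typical_p(None) == 1.0
assert check_typical_p(0.95) == 0.95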
@@ -231,6 +244,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         seed,
         watermark,
@@ -275,10 +289,12 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("`top_p` must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
@@ -10,6 +10,7 @@ def default_pb_parameters():
        repetition_penalty=1.0,
        top_k=0,
        top_p=1.0,
+        typical_p=1.0,
        do_sample=False,
    )

@@ -6,6 +6,7 @@ from transformers import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
+    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
@@ -41,6 +42,7 @@ class NextTokenChooser:
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
+        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
@@ -64,6 +66,9 @@ class NextTokenChooser:
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p))
+            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()
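`TypicalLogitsWarper` comes straight from `transformers` and is only appended to the warper chain when `typical_p` is set below 1.0, which also flips the chooser from greedy decoding to sampling via `sampling = True`. The standalone sketch below shows what the warper does to a single row of made-up logits; the numbers are illustrative only.

# Sketch only: the effect of the typical-sampling warper on one logits row.
# Tokens outside the locally typical set are pushed to -inf and can never be sampled.
import torch
from transformers import TypicalLogitsWarper

logits = torch.tensor([[4.0, 3.5, 1.0, -2.0, -5.0]])   # 1 sequence, toy 5-token vocabulary
input_ids = torch.zeros((1, 1), dtype=torch.long)       # unused by this warper, required by the API
warper = TypicalLogitsWarper(mass=0.95)                  # same keyword the server passes
print(warper(input_ids, logits))                         # filtered entries become -inf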
@@ -92,6 +97,7 @@ class NextTokenChooser:
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
+            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,