feat: support typical sampling (#114)

closes #112
OlivierDehaene 2023-03-09 11:33:57 +01:00 committed by GitHub
parent 941cd42e0c
commit 1a2d68250a
7 changed files with 55 additions and 18 deletions

View File

@@ -34,14 +34,16 @@ message NextTokenChooserParameters {
     uint32 top_k = 2;
     /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
     float top_p = 3;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float typical_p = 4;
     /// apply sampling on the logits
-    bool do_sample = 4;
+    bool do_sample = 5;
     /// random seed for sampling
-    uint64 seed = 5;
+    uint64 seed = 6;
     /// repetition penalty
-    float repetition_penalty = 6;
+    float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
-    bool watermark = 7;
+    bool watermark = 8;
 }
 
 message StoppingCriteriaParameters {

View File

@@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters {
     )]
     pub top_p: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub typical_p: Option<f32>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
@@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
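
Once the router exposes the new field, a request can set it alongside the other sampling parameters. A minimal sketch using Python's requests against the router's /generate route; the host, port, and prompt below are placeholders for illustration:

import requests

# Minimal sketch: exercise the new `typical_p` request parameter.
# The URL and prompt are illustrative placeholders, not part of this commit.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "typical_p": 0.95,     # restrict sampling to the "typical" token set
            "do_sample": True,     # typical_p only takes effect when sampling
            "max_new_tokens": 20,
        },
    },
)
print(response.json())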

View File

@@ -231,6 +231,7 @@ mod tests {
            temperature: 0.0,
            top_k: 0,
            top_p: 0.0,
+           typical_p: 0.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 0.0,

View File

@@ -68,6 +68,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
            repetition_penalty: None,
            top_k: None,
            top_p: None,
+           typical_p: None,
            do_sample: false,
            max_new_tokens: 1,
            return_full_text: None,

View File

@@ -153,6 +153,7 @@ fn validate(
        repetition_penalty,
        top_k,
        top_p,
+       typical_p,
        do_sample,
        max_new_tokens,
        stop: stop_sequences,
@@ -171,22 +172,34 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }
 
-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
-
-    // Different because the proto default value is 0 while it is not a valid value
-    // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(top_k as u32)
-        }
-    }?;
+    // Different because the proto default value is not a valid value
+    // for the user
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;
 
     if max_new_tokens == 0 {
         return Err(ValidationError::MaxNewTokens);
@@ -231,6 +244,7 @@ fn validate(
        repetition_penalty,
        top_k,
        top_p,
+       typical_p,
        do_sample,
        seed,
        watermark,
@@ -275,10 +289,12 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("`top_p` must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]

View File

@@ -10,6 +10,7 @@ def default_pb_parameters():
        repetition_penalty=1.0,
        top_k=0,
        top_p=1.0,
+       typical_p=1.0,
        do_sample=False,
    )

View File

@@ -6,6 +6,7 @@ from transformers import (
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
 )
@@ -41,6 +42,7 @@ class NextTokenChooser:
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
+        typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
@@ -64,6 +66,9 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p))
+            sampling = True
 
         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
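
For intuition, TypicalLogitsWarper implements locally typical sampling (Meister et al., 2022): it keeps the smallest set of tokens whose surprisal is closest to the entropy of the next-token distribution and whose cumulative probability reaches the requested mass. A rough single-sequence sketch of that filtering step, for illustration only (the server itself relies on the transformers warper shown in the diff):

import torch

def typical_filter(logits: torch.Tensor, mass: float = 0.95) -> torch.Tensor:
    # Illustrative sketch of typical filtering for a single 1-D logits vector.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # Entropy of the next-token distribution.
    entropy = -(probs * log_probs).sum()
    # A token is "typical" when its surprisal (-log p) is close to the entropy.
    deviation = (log_probs + entropy).abs()
    _, sorted_idx = torch.sort(deviation)
    cumulative = probs[sorted_idx].cumsum(dim=-1)
    # Keep the smallest prefix of most-typical tokens whose mass reaches `mass`.
    keep = sorted_idx[: int((cumulative < mass).sum()) + 1]
    filtered = torch.full_like(logits, float("-inf"))
    filtered[keep] = logits[keep]
    return filtered
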
@@ -92,6 +97,7 @@ class NextTokenChooser:
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
+           typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,