mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

Add support for no_repeat_ngram_size

This commit is contained in:
parent dbb23fbfa8
commit 28e6a504c0
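As a quick illustration of what this commit enables, here is a minimal client-side sketch of a request that sets the new parameter through TGI's generate route. It is an assumption, not part of the diff: it presumes a server already listening on localhost:8080, and the payload shape follows the GenerateParameters struct extended below; the value 3 is just an example.

# Hypothetical client-side usage of the new parameter (not part of the diff).
# Assumes a text-generation-inference server listening on localhost:8080.
import requests

response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Write a short story about a robot.",
        "parameters": {
            "max_new_tokens": 128,
            # Added by this commit: never generate the same 3-gram twice.
            "no_repeat_ngram_size": 3,
        },
    },
)
print(response.json()["generated_text"])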
@@ -74,6 +74,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// frequency penalty
     float frequency_penalty = 9;
+    /// no_repeat_ngram_size
+    uint32 no_repeat_ngram_size = 12;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
@@ -226,6 +226,13 @@ pub(crate) struct GenerateParameters {
     )]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// The number of highest probability vocabulary tokens to keep for top-k-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
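The doc comment above describes the intent; the blocking rule itself can be illustrated with a small standalone sketch (the idea only, not the library code the server uses): before each new token, collect every n-gram already present in the generated ids and ban any token that would recreate one of them.

# Standalone sketch of no-repeat-n-gram blocking over token ids (illustration only).
from typing import List, Set

def banned_next_tokens(generated: List[int], ngram_size: int) -> Set[int]:
    """Return token ids that would complete an n-gram already seen in `generated`."""
    if ngram_size <= 0 or len(generated) < ngram_size:
        return set()
    # The (n-1)-token suffix that the next token would extend into an n-gram.
    prefix = tuple(generated[len(generated) - (ngram_size - 1):]) if ngram_size > 1 else ()
    banned = set()
    for i in range(len(generated) - ngram_size + 1):
        ngram = tuple(generated[i:i + ngram_size])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned

# With bi-grams (n=2) and the history [5, 7, 9, 5], the pair (5, 7) has already
# occurred, so 7 is banned as the next token after the trailing 5.
print(banned_next_tokens([5, 7, 9, 5], ngram_size=2))  # {7}

A sampler would then set the logits of the banned ids to -inf before applying temperature and top-k/top-p filtering.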
@@ -330,6 +337,7 @@ fn default_parameters() -> GenerateParameters {
         temperature: None,
         repetition_penalty: None,
         frequency_penalty: None,
+        no_repeat_ngram_size: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -427,6 +435,13 @@ pub struct CompletionRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// Up to 4 sequences where the API will stop generating further tokens.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -743,6 +758,13 @@ pub(crate) struct ChatRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// UNUSED
     /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
     /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
@@ -653,6 +653,7 @@ async fn completions(
         temperature,
         repetition_penalty: req.repetition_penalty,
         frequency_penalty: req.frequency_penalty,
+        no_repeat_ngram_size: req.no_repeat_ngram_size,
         top_k: None,
         top_p: req.top_p,
         typical_p: None,
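Since the OpenAI-compatible completions handler above now forwards req.no_repeat_ngram_size, a client can pass the field alongside the other sampling options. The sketch below is an assumption about the payload shape against a hypothetical local server; field names other than no_repeat_ngram_size are taken from the CompletionRequest struct shown earlier and the usual OpenAI-style schema.

# Hypothetical request to the OpenAI-style completions route (not part of the diff).
import requests

payload = {
    "model": "tgi",                # placeholder; assumed not to affect routing here
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "frequency_penalty": 0.5,
    "no_repeat_ngram_size": 2,     # forwarded as req.no_repeat_ngram_size above
}
response = requests.post("http://localhost:8080/v1/completions", json=payload)
print(response.json())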
@@ -1099,6 +1100,7 @@ async fn chat_completions(
         temperature,
         repetition_penalty,
         frequency_penalty: req.frequency_penalty,
+        no_repeat_ngram_size: req.no_repeat_ngram_size,
         top_k: None,
         top_p: req.top_p,
         typical_p: None,
@@ -18,6 +18,7 @@ from transformers import (
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
+    NoRepeatNGramLogitsProcessor
 )
 
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -30,6 +31,7 @@ class StaticWarper:
         top_k=None,
         top_p=None,
         typical_p=None,
+        no_repeat_ngram_size=None,
     ):
         self.warpers = []
 
@@ -42,6 +44,8 @@ class StaticWarper:
             self.warpers.append(TopPLogitsWarper(top_p=top_p))
         if typical_p is not None and typical_p < 1.0:
             self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
+            self.warpers.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
 
         self.cuda_graph = None
         self.static_scores = None
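NoRepeatNGramLogitsProcessor comes from transformers and is applied like the other entries in the warpers list: given the running input_ids and the next-token scores, it pushes the logits of any token that would repeat an existing n-gram down to -inf. A minimal demonstration of that behavior with dummy logits (not server code):

# Minimal demonstration of the processor appended above (illustration only).
import torch
from transformers import NoRepeatNGramLogitsProcessor

processor = NoRepeatNGramLogitsProcessor(2)     # block repeated bi-grams
input_ids = torch.tensor([[5, 7, 9, 5]])        # a batch with one sequence
scores = torch.zeros(1, 12)                     # dummy next-token logits, vocab of 12
filtered = processor(input_ids, scores)
print(filtered[0, 7])                           # tensor(-inf): (5, 7) was already generated

Because the proto field is a uint32 that defaults to 0 and the guard above only appends the processor when no_repeat_ngram_size > 0, the blocking stays inactive unless a request sets the parameter explicitly.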
@@ -82,9 +86,10 @@ def static_warper(
     top_k: Optional[int],
     top_p: Optional[float],
     typical_p: Optional[float],
+    no_repeat_ngram_size: Optional[int],
 ) -> StaticWarper:
     return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
     )
 
 
@@ -29,6 +29,7 @@ class NextTokenChooser:
         temperature: float = 1.0,
         repetition_penalty: float = 1.0,
         frequency_penalty: float = 0.0,
+        no_repeat_ngram_size: int = 0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         typical_p: Optional[float] = None,
@@ -65,10 +66,11 @@ class NextTokenChooser:
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
+            or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0)
         )
         if has_warpers:
             self.static_warper = static_warper(
-                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
             )
         else:
             self.static_warper = None
@@ -117,6 +119,7 @@ class NextTokenChooser:
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             frequency_penalty=pb.frequency_penalty,
+            no_repeat_ngram_size=pb.no_repeat_ngram_size,
             top_k=pb.top_k,
             top_p=pb.top_p,
             typical_p=pb.typical_p,