diff --git a/proto/generate.proto b/proto/generate.proto
index 6351e37f..68cd5fc9 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -74,6 +74,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// frequency penalty
     float frequency_penalty = 9;
+    /// no_repeat_ngram_size
+    uint32 no_repeat_ngram_size = 12;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
diff --git a/router/src/lib.rs b/router/src/lib.rs
index f856406d..463fd795 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -226,6 +226,13 @@ pub(crate) struct GenerateParameters {
     )]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// The number of highest probability vocabulary tokens to keep for top-k-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
@@ -330,6 +337,7 @@ fn default_parameters() -> GenerateParameters {
         temperature: None,
         repetition_penalty: None,
         frequency_penalty: None,
+        no_repeat_ngram_size: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -427,6 +435,13 @@ pub struct CompletionRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// Up to 4 sequences where the API will stop generating further tokens.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -743,6 +758,13 @@ pub(crate) struct ChatRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// UNUSED
     /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
     /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
diff --git a/router/src/server.rs b/router/src/server.rs
index d3a280ca..25349524 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -653,6 +653,7 @@ async fn completions(
             temperature,
             repetition_penalty: req.repetition_penalty,
             frequency_penalty: req.frequency_penalty,
+            no_repeat_ngram_size: req.no_repeat_ngram_size,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
@@ -1099,6 +1100,7 @@ async fn chat_completions(
             temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
+            no_repeat_ngram_size: req.no_repeat_ngram_size,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 6b915437..ddc6de73 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -18,6 +18,7 @@ from transformers import (
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
+    NoRepeatNGramLogitsProcessor,
 )
 
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -30,6 +31,7 @@ class StaticWarper:
         top_k=None,
         top_p=None,
         typical_p=None,
+        no_repeat_ngram_size=None,
     ):
         self.warpers = []
 
@@ -42,6 +44,8 @@ class StaticWarper:
             self.warpers.append(TopPLogitsWarper(top_p=top_p))
         if typical_p is not None and typical_p < 1.0:
             self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
+            self.warpers.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
 
         self.cuda_graph = None
         self.static_scores = None
@@ -82,9 +86,10 @@ def static_warper(
     top_k: Optional[int],
     top_p: Optional[float],
     typical_p: Optional[float],
+    no_repeat_ngram_size: Optional[int],
 ) -> StaticWarper:
     return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
     )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 22f86b60..a34273bc 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -29,6 +29,7 @@ class NextTokenChooser:
         temperature: float = 1.0,
         repetition_penalty: float = 1.0,
         frequency_penalty: float = 0.0,
+        no_repeat_ngram_size: int = 0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         typical_p: Optional[float] = None,
@@ -65,10 +66,11 @@ class NextTokenChooser:
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
+            or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0)
         )
         if has_warpers:
             self.static_warper = static_warper(
-                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
             )
         else:
             self.static_warper = None
@@ -117,6 +119,7 @@ class NextTokenChooser:
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             frequency_penalty=pb.frequency_penalty,
+            no_repeat_ngram_size=pb.no_repeat_ngram_size,
             top_k=pb.top_k,
             top_p=pb.top_p,
             typical_p=pb.typical_p,
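
Two notes on the behaviour this patch wires in. Server-side, the blocking comes from `transformers.NoRepeatNGramLogitsProcessor`, which the diff appends to `StaticWarper`'s warper list: at each decoding step it collects every n-gram of the configured size already present in the sequence and forces the logit of any token that would complete a repeat to `-inf`. A minimal standalone sketch of that processor (not part of the diff; the token IDs and the 10-token vocabulary are arbitrary stand-ins):

```python
import torch
from transformers import NoRepeatNGramLogitsProcessor

# Ban repeated bi-grams, as no_repeat_ngram_size=2 would.
processor = NoRepeatNGramLogitsProcessor(2)

# Sequence so far: [5, 7, 5]. The bi-gram (5, 7) has already occurred and the
# prefix now ends in 5, so emitting 7 next would repeat it.
input_ids = torch.tensor([[5, 7, 5]])
scores = torch.zeros((1, 10))  # dummy logits over a 10-token vocabulary

filtered = processor(input_ids, scores)
assert filtered[0, 7] == float("-inf")  # token 7 is banned at this step
```

Client-side, the new field travels through `GenerateParameters`, so it can be exercised end to end once a patched server is running. A hypothetical request against the router's `/generate` route (the host and port are assumptions for a local deployment; adjust to yours):

```python
import requests

resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "She runs fast. She runs",
        "parameters": {"max_new_tokens": 32, "no_repeat_ngram_size": 2},
    },
)
print(resp.json()["generated_text"])
```

Unlike `repetition_penalty` or `frequency_penalty`, this is a hard constraint rather than a soft one: a matching token's probability drops to zero, so small sizes (e.g. 2) can also suppress legitimate repetition such as proper names.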