From ea915ad7d711b78436801f76bc6e03f743e8a77b Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 15 Jul 2024 13:51:11 +0000 Subject: [PATCH] Add support for no_repeat_ngram_size --- proto/generate.proto | 2 ++ router/src/lib.rs | 22 +++++++++++++++++++ router/src/server.rs | 2 ++ .../utils/logits_process.py | 7 +++++- server/text_generation_server/utils/tokens.py | 5 ++++- 5 files changed, 36 insertions(+), 2 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 6351e37f..68cd5fc9 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -74,6 +74,8 @@ message NextTokenChooserParameters { float repetition_penalty = 7; /// frequency penalty float frequency_penalty = 9; + /// no_repeat_ngram_size + uint32 no_repeat_ngram_size = 12; /// token watermarking using "A Watermark for Large Language Models" bool watermark = 8; /// grammar (applied if not empty) diff --git a/router/src/lib.rs b/router/src/lib.rs index 14bb8270..c12d421b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -223,6 +223,13 @@ pub(crate) struct GenerateParameters { )] pub frequency_penalty: Option<f32>, + /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the + /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid + /// generating the same n-grams in the completion. + #[serde(default)] + #[schema(nullable = true, example = "12")] + pub no_repeat_ngram_size: Option<u32>, + /// The number of highest probability vocabulary tokens to keep for top-k-filtering. 
#[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] @@ -327,6 +334,7 @@ fn default_parameters() -> GenerateParameters { temperature: None, repetition_penalty: None, frequency_penalty: None, + no_repeat_ngram_size: None, top_k: None, top_p: None, typical_p: None, @@ -424,6 +432,13 @@ pub struct CompletionRequest { #[schema(example = "1.0")] pub frequency_penalty: Option<f32>, + /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the + /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid + /// generating the same n-grams in the completion. + #[serde(default)] + #[schema(nullable = true, example = "12")] + pub no_repeat_ngram_size: Option<u32>, + /// Up to 4 sequences where the API will stop generating further tokens. #[serde(default)] #[schema(nullable = true, example = "null")] @@ -740,6 +755,13 @@ pub(crate) struct ChatRequest { #[schema(example = "1.0")] pub frequency_penalty: Option<f32>, + /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the + /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid + /// generating the same n-grams in the completion. + #[serde(default)] + #[schema(nullable = true, example = "12")] + pub no_repeat_ngram_size: Option<u32>, + /// UNUSED /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. 
Mathematically, diff --git a/router/src/server.rs b/router/src/server.rs index dcbaa2ad..2972f820 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -654,6 +654,7 @@ async fn completions( temperature, repetition_penalty: req.repetition_penalty, frequency_penalty: req.frequency_penalty, + no_repeat_ngram_size: req.no_repeat_ngram_size, top_k: None, top_p: req.top_p, typical_p: None, @@ -1100,6 +1101,7 @@ async fn chat_completions( temperature, repetition_penalty, frequency_penalty: req.frequency_penalty, + no_repeat_ngram_size: req.no_repeat_ngram_size, top_k: None, top_p: req.top_p, typical_p: None, diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 9abd886f..452691a2 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -18,6 +18,7 @@ from transformers import ( TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, + NoRepeatNGramLogitsProcessor ) mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -30,6 +31,7 @@ class StaticWarper: top_k=None, top_p=None, typical_p=None, + no_repeat_ngram_size=None, ): self.warpers = [] @@ -42,6 +44,8 @@ class StaticWarper: self.warpers.append(TopPLogitsWarper(top_p=top_p)) if typical_p is not None and typical_p < 1.0: self.warpers.append(TypicalLogitsWarper(mass=typical_p)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + self.warpers.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) self.cuda_graph = None self.static_scores = None @@ -82,9 +86,10 @@ def static_warper( top_k: Optional[int], top_p: Optional[float], typical_p: Optional[float], + no_repeat_ngram_size: Optional[int], ) -> StaticWarper: return StaticWarper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, 
no_repeat_ngram_size=no_repeat_ngram_size ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9ab49665..b7a794b9 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -28,6 +28,7 @@ class NextTokenChooser: temperature: float = 1.0, repetition_penalty: float = 1.0, frequency_penalty: float = 0.0, + no_repeat_ngram_size: int = 0, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, @@ -64,10 +65,11 @@ class NextTokenChooser: or (top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (typical_p is not None and typical_p < 1.0) + or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0) ) if has_warpers: self.static_warper = static_warper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size ) else: self.static_warper = None @@ -116,6 +118,7 @@ class NextTokenChooser: temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, frequency_penalty=pb.frequency_penalty, + no_repeat_ngram_size=pb.no_repeat_ngram_size, top_k=pb.top_k, top_p=pb.top_p, typical_p=pb.typical_p,