From a1c630d5c16cc58d0f38b76c1135e1cbdc31ec79 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Feb 2024 17:40:19 +0000 Subject: [PATCH] fix: update grammar states after filter --- benchmark/src/lib.rs | 2 +- router/client/src/client.rs | 4 ++-- router/src/health.rs | 4 ++-- router/src/lib.rs | 5 ++++- router/src/queue.rs | 4 ++-- router/src/validation.rs | 7 ++++++- server/text_generation_server/utils/tokens.py | 13 +++++++++---- 7 files changed, 26 insertions(+), 13 deletions(-) diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index a0ef0fe6..e290665c 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -45,7 +45,7 @@ pub async fn run( repetition_penalty: repetition_penalty.unwrap_or(1.0), frequency_penalty: frequency_penalty.unwrap_or(0.0), watermark, - fsm_grammar_state: 0, + fsm_grammar_state: Vec::new(), }; // Initialize terminal properties diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 38e6e0e3..f26b6a69 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -128,8 +128,8 @@ impl Client { repetition_penalty: 1.2, frequency_penalty: 0.1, watermark: true, - grammar: String::new(), - fsm_grammar_state: 0, + grammar: Vec::new(), + fsm_grammar_state: Vec::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: max_total_tokens - truncate, diff --git a/router/src/health.rs b/router/src/health.rs index f3cac17e..0605dca8 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -45,8 +45,8 @@ impl Health { repetition_penalty: 1.0, frequency_penalty: 0.0, watermark: false, - grammar: String::new(), - fsm_grammar_state: 0, + grammar: Vec::new(), + fsm_grammar_state: Vec::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/router/src/lib.rs b/router/src/lib.rs index 03b7ce88..b13d84ee 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -201,7 +201,10 @@ pub(crate) struct GenerateParameters { #[serde(default)] 
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] pub top_n_tokens: Option<u32>, - #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")] + #[serde( + default, + deserialize_with = "json_object_or_string_to_string::deserialize" + )] pub grammar: String, } diff --git a/router/src/queue.rs b/router/src/queue.rs index 0162b906..21467aa7 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -368,8 +368,8 @@ mod tests { repetition_penalty: 0.0, frequency_penalty: 0.0, watermark: false, - grammar: String::new(), - fsm_grammar_state: 0, + grammar: Vec::new(), + fsm_grammar_state: Vec::new(), }, stopping_parameters: StoppingCriteriaParameters { ignore_eos_token: false, diff --git a/router/src/validation.rs b/router/src/validation.rs index 0455411d..83a68435 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -293,6 +293,11 @@ impl Validation { .validate_input(request.inputs, truncate, max_new_tokens) .await?; + // initialize the grammar parameter + let grammar = vec![grammar]; + // init the start state of the grammar + let fsm_grammar_state = vec![0]; + let parameters = NextTokenChooserParameters { temperature, repetition_penalty, @@ -304,7 +309,7 @@ impl Validation { seed, watermark, grammar, - fsm_grammar_state: 0, + fsm_grammar_state, }; let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 3178c2f2..67fdcec0 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -409,7 +409,9 @@ class HeterogeneousNextTokenChooser: self.do_sample = [self.do_sample[i] for i in indices] if self.use_grammar or any(self.do_sample): - self.choice.filter(indices, self.fsm_grammar_states, self.grammars) + _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars) + 
self.fsm_grammar_states = new_fsm_grammar_states + self.grammars = new_grammars else: self.choice = Greedy() @@ -477,6 +479,11 @@ class Grammar: if fsm_grammar_states[i] == -1: continue + # if grammar is '' or None, return the greedy token + if grammars[i] == "" or grammars[i] is None: + empty[i] = logits[i].argmax().item() + continue + # this is cached and should be fast after the first time fsm = self.compile_fsm(grammars[i], self.tokenizer) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) @@ -546,9 +553,7 @@ class Grammar: new_fsm_grammar_states.append(fsm_grammar_states[i]) new_grammars.append(grammars[i]) - self.fsm_state = new_fsm_grammar_states - self.fsm = new_grammars - return self + return self, new_fsm_grammar_states, new_grammars class HeterogeneousSampling: