fix: update grammar states after filter

drbh 2024-02-09 17:40:19 +00:00
parent ff6e8d9e23
commit a1c630d5c1
7 changed files with 26 additions and 13 deletions


@@ -45,7 +45,7 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
-        fsm_grammar_state: 0,
+        fsm_grammar_state: Vec::new(),
     };
     // Initialize terminal properties


@@ -128,8 +128,8 @@ impl Client {
                 repetition_penalty: 1.2,
                 frequency_penalty: 0.1,
                 watermark: true,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,


@@ -45,8 +45,8 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,


@@ -201,7 +201,10 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
+    #[serde(
+        default,
+        deserialize_with = "json_object_or_string_to_string::deserialize"
+    )]
     pub grammar: String,
 }


@@ -368,8 +368,8 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,


@@ -293,6 +293,11 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
+        // initialize the grammar parameter
+        let grammar = vec![grammar];
+        // init the start state of the grammar
+        let fsm_grammar_state = vec![0];
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -304,7 +309,7 @@
             seed,
             watermark,
             grammar,
-            fsm_grammar_state: 0,
+            fsm_grammar_state,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,


@@ -409,7 +409,9 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = [self.do_sample[i] for i in indices]
         if self.use_grammar or any(self.do_sample):
-            self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
+            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
+            self.fsm_grammar_states = new_fsm_grammar_states
+            self.grammars = new_grammars
         else:
             self.choice = Greedy()
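
For context, here is a minimal, self-contained sketch of the pattern this hunk adopts. The GrammarChooser class and its field names are illustrative stand-ins, not the actual TGI types; the point is that per-request grammar state must be re-selected by the surviving batch indices and then stored back by the caller:

from typing import List, Tuple


class GrammarChooser:
    def __init__(self, grammars: List[str], states: List[int]):
        self.grammars = grammars
        self.fsm_grammar_states = states

    def filter(
        self, indices: List[int], states: List[int], grammars: List[str]
    ) -> Tuple["GrammarChooser", List[int], List[str]]:
        # keep only the entries for requests that remain in the batch
        new_states = [states[i] for i in indices]
        new_grammars = [grammars[i] for i in indices]
        return self, new_states, new_grammars


chooser = GrammarChooser(["g0", "g1", "g2"], [0, 5, 7])
# request 1 finished; requests 0 and 2 survive
_, new_states, new_grammars = chooser.filter([0, 2], chooser.fsm_grammar_states, chooser.grammars)
chooser.fsm_grammar_states = new_states   # without these assignments the stale,
chooser.grammars = new_grammars           # longer lists would be reused on the next step
assert chooser.fsm_grammar_states == [0, 7]
assert chooser.grammars == ["g0", "g2"]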
@@ -477,6 +479,11 @@ class Grammar:
             if fsm_grammar_states[i] == -1:
                 continue
+            # if grammar is '' or None, return the greedy token
+            if grammars[i] == "" or grammars[i] is None:
+                empty[i] = logits[i].argmax().item()
+                continue
             # this is cached and should be fast after the first time
             fsm = self.compile_fsm(grammars[i], self.tokenizer)
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
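
A small illustrative sketch of the fallback this hunk adds (tensor shapes and variable names are assumed, not taken from the TGI source): requests whose grammar is empty or None skip the FSM masking entirely and take a plain greedy argmax over their logits.

import torch

logits = torch.tensor([[0.1, 2.0, 0.3],   # request 0: no grammar -> greedy pick
                       [1.5, 0.2, 0.1]])  # request 1: has a grammar -> FSM-masked path
grammars = ["", '{"type": "object"}']

next_tokens = torch.zeros(logits.shape[0], dtype=torch.long)
for i in range(logits.shape[0]):
    if grammars[i] == "" or grammars[i] is None:
        # no constraint: the highest-probability token wins
        next_tokens[i] = logits[i].argmax().item()
        continue
    # ... otherwise the compiled FSM would restrict logits to allowed_token_ids here

print(next_tokens.tolist())  # [1, 0] -> request 1 left at the placeholder 0 (FSM path omitted)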
@@ -546,9 +553,7 @@ class Grammar:
             new_fsm_grammar_states.append(fsm_grammar_states[i])
             new_grammars.append(grammars[i])
-        self.fsm_state = new_fsm_grammar_states
-        self.fsm = new_grammars
-        return self
+        return self, new_fsm_grammar_states, new_grammars


 class HeterogeneousSampling:
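
Taken together with the caller change above, the fix is that Grammar.filter now returns the filtered lists instead of stashing them in self.fsm_state and self.fsm, attribute names that do not match the fsm_grammar_states and grammars tracked elsewhere in these hunks; the chooser then stores the returned lists itself, so the grammar states stay aligned with the requests that remain in the batch after filtering.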