fix: update grammar states after filter

This commit is contained in:
drbh 2024-02-09 17:40:19 +00:00
parent ff6e8d9e23
commit a1c630d5c1
7 changed files with 26 additions and 13 deletions

View File

@ -45,7 +45,7 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
fsm_grammar_state: 0,
fsm_grammar_state: Vec::new(),
};
// Initialize terminal properties

View File

@ -128,8 +128,8 @@ impl Client {
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
fsm_grammar_state: 0,
grammar: Vec::new(),
fsm_grammar_state: Vec::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@ -45,8 +45,8 @@ impl Health {
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
fsm_grammar_state: 0,
grammar: Vec::new(),
fsm_grammar_state: Vec::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@ -201,7 +201,10 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
#[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
#[serde(
default,
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
pub grammar: String,
}

View File

@ -368,8 +368,8 @@ mod tests {
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
fsm_grammar_state: 0,
grammar: Vec::new(),
fsm_grammar_state: Vec::new(),
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,

View File

@ -293,6 +293,11 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
// initialize the grammar parameter
let grammar = vec![grammar];
// init the start state of the grammar
let fsm_grammar_state = vec![0];
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
@ -304,7 +309,7 @@ impl Validation {
seed,
watermark,
grammar,
fsm_grammar_state: 0,
fsm_grammar_state,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,

View File

@ -409,7 +409,9 @@ class HeterogeneousNextTokenChooser:
self.do_sample = [self.do_sample[i] for i in indices]
if self.use_grammar or any(self.do_sample):
self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
_, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
self.fsm_grammar_states = new_fsm_grammar_states
self.grammars = new_grammars
else:
self.choice = Greedy()
@ -477,6 +479,11 @@ class Grammar:
if fsm_grammar_states[i] == -1:
continue
# if grammar is '' or None, return the greedy token
if grammars[i] == "" or grammars[i] is None:
empty[i] = logits[i].argmax().item()
continue
# this is cached and should be fast after the first time
fsm = self.compile_fsm(grammars[i], self.tokenizer)
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
@ -546,9 +553,7 @@ class Grammar:
new_fsm_grammar_states.append(fsm_grammar_states[i])
new_grammars.append(grammars[i])
self.fsm_state = new_fsm_grammar_states
self.fsm = new_grammars
return self
return self, new_fsm_grammar_states, new_grammars
class HeterogeneousSampling: