mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: update grammar states after filter
This commit is contained in:
parent
ff6e8d9e23
commit
a1c630d5c1
@ -45,7 +45,7 @@ pub async fn run(
|
||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||
frequency_penalty: frequency_penalty.unwrap_or(0.0),
|
||||
watermark,
|
||||
fsm_grammar_state: 0,
|
||||
fsm_grammar_state: Vec::new(),
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
|
@ -128,8 +128,8 @@ impl Client {
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
fsm_grammar_state: 0,
|
||||
grammar: Vec::new(),
|
||||
fsm_grammar_state: Vec::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
|
@ -45,8 +45,8 @@ impl Health {
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
fsm_grammar_state: 0,
|
||||
grammar: Vec::new(),
|
||||
fsm_grammar_state: Vec::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
|
@ -201,7 +201,10 @@ pub(crate) struct GenerateParameters {
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||
pub top_n_tokens: Option<u32>,
|
||||
#[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
|
||||
#[serde(
|
||||
default,
|
||||
deserialize_with = "json_object_or_string_to_string::deserialize"
|
||||
)]
|
||||
pub grammar: String,
|
||||
}
|
||||
|
||||
|
@ -368,8 +368,8 @@ mod tests {
|
||||
repetition_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
fsm_grammar_state: 0,
|
||||
grammar: Vec::new(),
|
||||
fsm_grammar_state: Vec::new(),
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
ignore_eos_token: false,
|
||||
|
@ -293,6 +293,11 @@ impl Validation {
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
// initialize the grammar parameter
|
||||
let grammar = vec![grammar];
|
||||
// init the start state of the grammar
|
||||
let fsm_grammar_state = vec![0];
|
||||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
@ -304,7 +309,7 @@ impl Validation {
|
||||
seed,
|
||||
watermark,
|
||||
grammar,
|
||||
fsm_grammar_state: 0,
|
||||
fsm_grammar_state,
|
||||
};
|
||||
let stopping_parameters = StoppingCriteriaParameters {
|
||||
max_new_tokens,
|
||||
|
@ -409,7 +409,9 @@ class HeterogeneousNextTokenChooser:
|
||||
self.do_sample = [self.do_sample[i] for i in indices]
|
||||
|
||||
if self.use_grammar or any(self.do_sample):
|
||||
self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
|
||||
_, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
|
||||
self.fsm_grammar_states = new_fsm_grammar_states
|
||||
self.grammars = new_grammars
|
||||
else:
|
||||
self.choice = Greedy()
|
||||
|
||||
@ -477,6 +479,11 @@ class Grammar:
|
||||
if fsm_grammar_states[i] == -1:
|
||||
continue
|
||||
|
||||
# if grammar is '' or None, return the greedy token
|
||||
if grammars[i] == "" or grammars[i] is None:
|
||||
empty[i] = logits[i].argmax().item()
|
||||
continue
|
||||
|
||||
# this is cached and should be fast after the first time
|
||||
fsm = self.compile_fsm(grammars[i], self.tokenizer)
|
||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||
@ -546,9 +553,7 @@ class Grammar:
|
||||
new_fsm_grammar_states.append(fsm_grammar_states[i])
|
||||
new_grammars.append(grammars[i])
|
||||
|
||||
self.fsm_state = new_fsm_grammar_states
|
||||
self.fsm = new_grammars
|
||||
return self
|
||||
return self, new_fsm_grammar_states, new_grammars
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
|
Loading…
Reference in New Issue
Block a user