Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
fix: update grammar states after filter

parent ff6e8d9e23
commit a1c630d5c1
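
The change in brief: the `grammar` and `fsm_grammar_state` fields of `NextTokenChooserParameters` go from single values (a `String` and a scalar `0`) to per-request vectors, the router's `Validation` seeds them as one-element vectors (`vec![grammar]`, `vec![0]`), and on the Python server `Grammar.filter` now returns the filtered FSM states and grammars so that `HeterogeneousNextTokenChooser.filter` can persist them instead of discarding the result; that discarded state is the bug the commit title refers to.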
@@ -45,7 +45,7 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
-        fsm_grammar_state: 0,
+        fsm_grammar_state: Vec::new(),
     };

     // Initialize terminal properties
@@ -128,8 +128,8 @@ impl Client {
                 repetition_penalty: 1.2,
                 frequency_penalty: 0.1,
                 watermark: true,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -45,8 +45,8 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -201,7 +201,10 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
+    #[serde(
+        default,
+        deserialize_with = "json_object_or_string_to_string::deserialize"
+    )]
     pub grammar: String,
 }

@@ -368,8 +368,8 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                fsm_grammar_state: 0,
+                grammar: Vec::new(),
+                fsm_grammar_state: Vec::new(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
@@ -293,6 +293,11 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;

+        // initialize the grammar parameter
+        let grammar = vec![grammar];
+        // init the start state of the grammar
+        let fsm_grammar_state = vec![0];
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -304,7 +309,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            fsm_grammar_state: 0,
+            fsm_grammar_state,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
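
Taken together, the two Validation hunks show where the new vector fields come from: each validated request wraps its single grammar string and the FSM start state into one-element vectors (`vec![grammar]`, `vec![0]`), presumably so the Python server can keep one grammar and one FSM state per request once batches are concatenated, which is what the hunks below operate on.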
@@ -409,7 +409,9 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = [self.do_sample[i] for i in indices]

         if self.use_grammar or any(self.do_sample):
-            self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
+            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
+            self.fsm_grammar_states = new_fsm_grammar_states
+            self.grammars = new_grammars
         else:
             self.choice = Greedy()

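
This hunk is the consumer side of the fix named in the commit title: previously the return value of `self.choice.filter(...)` was thrown away, so after a batch was filtered the chooser kept stale per-request FSM states and grammars. Below is a minimal, runnable sketch of the corrected pattern; `ToyGrammar` and `ToyChooser` are hypothetical stand-ins, not TGI's actual classes.

    # Sketch of the filter-and-persist pattern; toy classes, not TGI's.
    class ToyGrammar:
        def filter(self, indices, fsm_grammar_states, grammars):
            # keep only the entries for requests still in the batch and
            # hand them back to the caller instead of stashing them here
            new_states = [fsm_grammar_states[i] for i in indices]
            new_grammars = [grammars[i] for i in indices]
            return self, new_states, new_grammars

    class ToyChooser:
        def __init__(self, states, grammars):
            self.choice = ToyGrammar()
            self.fsm_grammar_states = states
            self.grammars = grammars

        def filter(self, indices):
            # the fix: capture and persist the returned lists rather than
            # calling filter() and dropping its result
            _, new_states, new_grammars = self.choice.filter(
                indices, self.fsm_grammar_states, self.grammars
            )
            self.fsm_grammar_states = new_states
            self.grammars = new_grammars

    chooser = ToyChooser(states=[0, 3, 7], grammars=["a", "", "c"])
    chooser.filter([0, 2])
    assert chooser.fsm_grammar_states == [0, 7]
    assert chooser.grammars == ["a", "c"]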
@@ -477,6 +479,11 @@ class Grammar:
            if fsm_grammar_states[i] == -1:
                continue

+            # if grammar is '' or None, return the greedy token
+            if grammars[i] == "" or grammars[i] is None:
+                empty[i] = logits[i].argmax().item()
+                continue
+
            # this is cached and should be fast after the first time
            fsm = self.compile_fsm(grammars[i], self.tokenizer)
            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
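
The first Grammar hunk adds an early exit: a request whose grammar is the empty string or None skips FSM compilation entirely and takes the greedy argmax token. A self-contained sketch of that guard, assuming only `torch`; the `logits`, `grammars`, and `empty` variables are stand-ins for the surrounding loop's state:

    import torch

    logits = torch.tensor([[0.1, 2.5, 0.3], [1.0, 0.2, 0.1]])
    grammars = ["", '{"type": "object"}']
    fsm_grammar_states = [0, 0]
    empty = [0] * len(grammars)  # stand-in for the per-request token buffer

    for i in range(len(grammars)):
        if fsm_grammar_states[i] == -1:
            continue
        # if grammar is '' or None, take the greedy token directly
        if grammars[i] == "" or grammars[i] is None:
            empty[i] = logits[i].argmax().item()
            continue
        # otherwise compile the (cached) FSM and mask the logits (omitted)

Short-circuiting before `compile_fsm` means no FSM is ever built for a grammar-less request, even though compilation is cached after the first call (per the context comment).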
@@ -546,9 +553,7 @@ class Grammar:
             new_fsm_grammar_states.append(fsm_grammar_states[i])
             new_grammars.append(grammars[i])

-        self.fsm_state = new_fsm_grammar_states
-        self.fsm = new_grammars
-        return self
+        return self, new_fsm_grammar_states, new_grammars


 class HeterogeneousSampling:
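
The final hunk is the producer side of the same fix: the old code wrote the filtered lists to `self.fsm_state` and `self.fsm`, attributes nothing in these hunks reads (the chooser's copies live in `self.fsm_grammar_states` and `self.grammars`), so the filtered state was silently lost. Returning a tuple makes the hand-off explicit and matches the destructuring in the chooser hunk above. A quick check of the new contract with toy inputs (hypothetical class mirroring the diff's loop):

    class ToyGrammarFilter:
        def filter(self, indices, fsm_grammar_states, grammars):
            new_fsm_grammar_states = []
            new_grammars = []
            for i in indices:
                new_fsm_grammar_states.append(fsm_grammar_states[i])
                new_grammars.append(grammars[i])
            return self, new_fsm_grammar_states, new_grammars

    g = ToyGrammarFilter()
    _, states, grams = g.filter([1], [0, 5, -1], ["", "g1", "g2"])
    assert states == [5]
    assert grams == ["g1"]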