mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: persist grammar state after batch concat
This commit is contained in:
parent
e6bb3ff81f
commit
4ff9cb806b
@ -47,6 +47,7 @@ pub async fn run(
|
||||
watermark,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
grammar_state: 0,
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
|
@ -61,7 +61,7 @@
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.2376709,
|
||||
"logprob": -0.33666992,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
@ -180,7 +180,7 @@
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.23840332,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
@ -299,7 +299,7 @@
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.23840332,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
@ -418,7 +418,7 @@
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.23840332,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
|
@ -80,6 +80,8 @@ message NextTokenChooserParameters {
|
||||
string grammar = 10;
|
||||
/// grammar type
|
||||
GrammarType grammar_type = 11;
|
||||
/// grammar fsm state
|
||||
uint32 grammar_state = 12;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
|
@ -130,6 +130,7 @@ impl Client {
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
grammar_state: 0,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
|
@ -48,6 +48,7 @@ impl Health {
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: ProtoGrammarType::None as i32,
|
||||
grammar_state: 0,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
|
@ -372,6 +372,7 @@ mod tests {
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: ProtoGrammarType::None as i32,
|
||||
grammar_state: 0,
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
ignore_eos_token: false,
|
||||
|
@ -356,6 +356,7 @@ impl Validation {
|
||||
watermark,
|
||||
grammar,
|
||||
grammar_type,
|
||||
grammar_state: 0,
|
||||
};
|
||||
let stopping_parameters = StoppingCriteriaParameters {
|
||||
max_new_tokens,
|
||||
|
@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets = []
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
fsm_grammar_states = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
|
||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
|
||||
dtype=batches[0].next_token_chooser.dtype,
|
||||
device=batches[0].next_token_chooser.device,
|
||||
tokenizer=batches[0].next_token_chooser.tokenizer,
|
||||
fsm_grammar_states=fsm_grammar_states,
|
||||
)
|
||||
|
||||
speculative_ids = (
|
||||
|
@ -466,7 +466,10 @@ class HeterogeneousNextTokenChooser:
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fsm_grammar_states: Optional[List[int]] = None,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
if fsm_grammar_states is None:
|
||||
fsm_grammar_states = [pb_.grammar_state for pb_ in pb]
|
||||
return HeterogeneousNextTokenChooser(
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
@ -482,7 +485,7 @@ class HeterogeneousNextTokenChooser:
|
||||
tokenizer=tokenizer,
|
||||
grammars=[pb_.grammar for pb_ in pb],
|
||||
grammar_types=[pb_.grammar_type for pb_ in pb],
|
||||
fsm_grammar_states=[0] * len(pb),
|
||||
fsm_grammar_states=fsm_grammar_states,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user