fix: persist grammar state after batch concat

drbh 2024-02-28 19:44:51 +00:00
parent e6bb3ff81f
commit 4ff9cb806b
9 changed files with 18 additions and 5 deletions
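
Background for the fix: grammar-constrained decoding advances a per-request finite-state machine one token at a time, and HeterogeneousNextTokenChooser.from_pb previously rebuilt the chooser with fsm_grammar_states=[0] * len(pb) when batches were concatenated, rewinding every in-flight sequence to the grammar's start state. A minimal, runnable sketch of that failure mode (toy FSM and hypothetical names, not the production code):

    # Toy grammar: the output must match "ab" (state 0 --a--> 1 --b--> 2).
    TRANSITIONS = {0: {"a": 1}, 1: {"b": 2}}

    def allowed_tokens(state: int) -> set:
        # Tokens the grammar permits from the given FSM state.
        return set(TRANSITIONS.get(state, {}))

    state = TRANSITIONS[0]["a"]  # the sequence already emitted "a"; state == 1

    # Old behaviour: concatenation reset the state to 0, so the mask
    # re-allows "a" and forbids the "b" the sequence actually needs next.
    assert allowed_tokens(0) == {"a"}

    # Fixed behaviour: the live state survives concatenation.
    assert allowed_tokens(state) == {"b"}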

View File

@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        grammar_state: 0,
     };
     // Initialize terminal properties

View File

@@ -61,7 +61,7 @@
     },
     {
       "id": 29906,
-      "logprob": -0.2376709,
+      "logprob": -0.33666992,
       "special": false,
       "text": "2"
     },
@@ -180,7 +180,7 @@
     },
     {
       "id": 29906,
-      "logprob": -0.23840332,
+      "logprob": -0.33740234,
       "special": false,
       "text": "2"
     },
@@ -299,7 +299,7 @@
     },
     {
       "id": 29906,
-      "logprob": -0.23840332,
+      "logprob": -0.33740234,
      "special": false,
       "text": "2"
     },
@@ -418,7 +418,7 @@
     },
     {
       "id": 29906,
-      "logprob": -0.23840332,
+      "logprob": -0.33740234,
       "special": false,
       "text": "2"
     },
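
(The four hunks above touch a test snapshot: the expected logprobs for token "2" shift slightly, presumably regenerated because carrying the FSM state through concatenation changes the logits processing applied to continued sequences.)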

View File

@@ -80,6 +80,8 @@ message NextTokenChooserParameters {
     string grammar = 10;
     /// grammar type
     GrammarType grammar_type = 11;
+    /// grammar fsm state
+    uint32 grammar_state = 12;
 }
 message StoppingCriteriaParameters {
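
The new field carries the FSM state over the wire, so a chooser rebuilt from these parameters can resume mid-grammar instead of restarting. Every Rust call site in this commit initializes it to 0; only the Python server's concatenation path (further below) supplies live states.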

View File

@@ -130,6 +130,7 @@ impl Client {
                 watermark: true,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

View File

@@ -48,6 +48,7 @@ impl Health {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -372,6 +372,7 @@ mod tests {
             watermark: false,
             grammar: String::new(),
             grammar_type: ProtoGrammarType::None as i32,
+            grammar_state: 0,
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,

View File

@@ -356,6 +356,7 @@ impl Validation {
            watermark,
            grammar,
            grammar_type,
+           grammar_state: 0,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,

View File

@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)
             top_n_tokens.extend(batch.top_n_tokens)
@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
+            fsm_grammar_states=fsm_grammar_states,
         )
         speculative_ids = (
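
Taken together, the three hunks above collect each request's current FSM state while merging batches and hand them to from_pb. A condensed sketch of that plumbing (names taken from the diff; the wrapper function and the first = ... binding are illustrative, and HeterogeneousNextTokenChooser is assumed to be importable as in flash_causal_lm.py):

    def merged_chooser(batches):
        # Condensed sketch of the grammar-state plumbing in
        # FlashCausalLMBatch.concatenate; all other fields elided.
        parameters = []
        fsm_grammar_states = []
        for batch in batches:
            parameters.extend([r.parameters for r in batch.requests])
            # Carry each request's *current* FSM state, not a fresh zero.
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
        first = batches[0].next_token_chooser
        return HeterogeneousNextTokenChooser.from_pb(
            parameters,
            dtype=first.dtype,
            device=first.device,
            tokenizer=first.tokenizer,
            fsm_grammar_states=fsm_grammar_states,  # overrides the reset-to-zero default
        )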

View File

@@ -466,7 +466,10 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        fsm_grammar_states: Optional[List[int]] = None,
     ) -> "HeterogeneousNextTokenChooser":
+        if fsm_grammar_states is None:
+            fsm_grammar_states = [pb_.grammar_state for pb_ in pb]
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
@@ -482,7 +485,7 @@
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
             grammar_types=[pb_.grammar_type for pb_ in pb],
-            fsm_grammar_states=[0] * len(pb),
+            fsm_grammar_states=fsm_grammar_states,
         )
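
Making the new from_pb parameter Optional keeps both call sites working without churn: a fresh batch built from router requests passes nothing and falls back to each request's grammar_state from the proto (always 0 in this commit), while the concatenation path passes the live per-request states explicitly, so a merged batch resumes each grammar walk exactly where it left off.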