From 4ff9cb806b53002ea6a36456525138ee17e163a3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 28 Feb 2024 19:44:51 +0000
Subject: [PATCH] fix: persist grammar state after batch concat

---
 benchmark/src/lib.rs                                       | 1 +
 .../test_grammar_llama/test_flash_llama_grammar_load.json  | 8 ++++----
 proto/generate.proto                                       | 2 ++
 router/client/src/client.rs                                | 1 +
 router/src/health.rs                                       | 1 +
 router/src/queue.rs                                        | 1 +
 router/src/validation.rs                                   | 1 +
 server/text_generation_server/models/flash_causal_lm.py    | 3 +++
 server/text_generation_server/utils/tokens.py              | 5 ++++-
 9 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index 638c6514..034c056c 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        grammar_state: 0,
     };

     // Initialize terminal properties
diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
index b7b26a2c..f6bc6e56 100644
--- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
+++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
@@ -61,7 +61,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.2376709,
+        "logprob": -0.33666992,
         "special": false,
         "text": "2"
       },
@@ -180,7 +180,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
@@ -299,7 +299,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
@@ -418,7 +418,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
diff --git a/proto/generate.proto b/proto/generate.proto
index 6351e37f..59bf253e 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -80,6 +80,8 @@ message NextTokenChooserParameters {
     string grammar = 10;
     /// grammar type
     GrammarType grammar_type = 11;
+    /// grammar fsm state
+    uint32 grammar_state = 12;
 }

 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index f8658318..a8f7a499 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -130,6 +130,7 @@ impl Client {
                     watermark: true,
                     grammar: String::new(),
                     grammar_type: GrammarType::None as i32,
+                    grammar_state: 0,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index b05b3094..8e5cbf02 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -48,6 +48,7 @@ impl Health {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 52ea16ca..69d0aba8 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -372,6 +372,7 @@ mod tests {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 204dbf92..f829395f 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -356,6 +356,7 @@ impl Validation {
             watermark,
             grammar,
             grammar_type,
+            grammar_state: 0,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 988637d4..acd97f45 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []

         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []

@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)

             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)

             top_n_tokens.extend(batch.top_n_tokens)
@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
+            fsm_grammar_states=fsm_grammar_states,
         )

         speculative_ids = (
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 32789850..c64d2664 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -466,7 +466,10 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        fsm_grammar_states: Optional[List[int]] = None,
     ) -> "HeterogeneousNextTokenChooser":
+        if fsm_grammar_states is None:
+            fsm_grammar_states = [pb_.grammar_state for pb_ in pb]
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
@@ -482,7 +485,7 @@ class HeterogeneousNextTokenChooser:
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
             grammar_types=[pb_.grammar_type for pb_ in pb],
-            fsm_grammar_states=[0] * len(pb),
+            fsm_grammar_states=fsm_grammar_states,
         )
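
Note (not part of the patch): the sketch below illustrates the concatenation pattern the patch introduces, namely collecting each batch's current per-request grammar FSM state and handing it back to the chooser constructor instead of resetting it to zero. It is a minimal, self-contained Python illustration; the names SimpleChooser, Batch, from_requests, and concatenate are hypothetical stand-ins for HeterogeneousNextTokenChooser.from_pb and FlashCausalLMBatch.concatenate, and only the idea of carrying the states across a concat is taken from the diff above.

# Illustrative sketch only; assumes the simplified types named in the lead-in,
# not the actual text-generation-inference classes.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SimpleChooser:
    # One FSM state per request in the batch; 0 means "start of grammar".
    fsm_grammar_states: List[int] = field(default_factory=list)

    @classmethod
    def from_requests(
        cls, n_requests: int, fsm_grammar_states: Optional[List[int]] = None
    ) -> "SimpleChooser":
        # Before the fix, states were unconditionally reset to [0] * n_requests here.
        if fsm_grammar_states is None:
            fsm_grammar_states = [0] * n_requests
        return cls(fsm_grammar_states=fsm_grammar_states)


@dataclass
class Batch:
    chooser: SimpleChooser


def concatenate(batches: List[Batch]) -> Batch:
    # Carry each request's current grammar FSM state into the merged batch
    # so constrained decoding resumes where it left off.
    states: List[int] = []
    for b in batches:
        states.extend(b.chooser.fsm_grammar_states)
    return Batch(chooser=SimpleChooser.from_requests(len(states), fsm_grammar_states=states))


if __name__ == "__main__":
    merged = concatenate([Batch(SimpleChooser([3, 7])), Batch(SimpleChooser([5]))])
    print(merged.chooser.fsm_grammar_states)  # [3, 7, 5] rather than [0, 0, 0]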