diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
index b7b26a2c4..f6bc6e567 100644
--- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
+++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
@@ -61,7 +61,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.2376709,
+          "logprob": -0.33666992,
           "special": false,
           "text": "2"
         },
@@ -180,7 +180,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
@@ -299,7 +299,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
@@ -418,7 +418,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index a109ee83f..ee2e11cff 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -408,7 +408,7 @@ class CausalLMBatch(Batch):
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device,
             batches[dst_batch_idx].next_token_chooser.tokenizer,
-            hq_env.is_quantization_enabled
+            quantization_enabled=hq_env.is_quantization_enabled,
         )
 
         input_ids = batches[dst_batch_idx].input_ids
@@ -463,7 +463,11 @@ class CausalLMBatch(Batch):
             parameters.append(parameters[0])
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            parameters, dtype, device, tokenizer, hq_env.is_quantization_enabled
+            pb=parameters,
+            dtype=dtype,
+            device=device,
+            tokenizer=tokenizer,
+            quantization_enabled=hq_env.is_quantization_enabled,
         )
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 988637d43..acd97f453 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
 
@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)
             top_n_tokens.extend(batch.top_n_tokens)
 
@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
+            fsm_grammar_states=fsm_grammar_states,
         )
 
         speculative_ids = (
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 3557edb1d..c879e312d 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -467,7 +467,8 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
-        quantization_enabled: bool,
+        fsm_grammar_states: Optional[List[int]] = None,
+        quantization_enabled: bool = False,
     ) -> "HeterogeneousNextTokenChooser":
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
@@ -484,7 +485,9 @@ class HeterogeneousNextTokenChooser:
            tokenizer=tokenizer,
            grammars=[pb_.grammar for pb_ in pb],
            grammar_types=[pb_.grammar_type for pb_ in pb],
-            fsm_grammar_states=[0] * len(pb),
+            fsm_grammar_states=(
+                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
+            ),
            quantization_enabled=quantization_enabled,
        )
 