diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 77e738dd..16d2c408 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -17,4 +17,4 @@ def default_pb_parameters():
 
 @pytest.fixture
 def default_pb_stop_parameters():
-    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
\ No newline at end of file
+    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
diff --git a/server/tests/models/test_grammar.py b/server/tests/models/test_grammar.py
new file mode 100644
index 00000000..b5e65620
--- /dev/null
+++ b/server/tests/models/test_grammar.py
@@ -0,0 +1,245 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import json
+import pytest
+import torch
+
+from copy import copy
+
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models import get_model
+from text_generation_server.models.causal_lm import (
+    CausalLMBatch,
+    PAD_SEQUENCE_TO_MULTIPLE_OF,
+)
+
+PAD_TOKEN = 0
+
+
+@pytest.fixture
+def default_pb_grammar_parameters():
+    grammar_schema = {
+        "properties": {
+            "activity": {
+                "type": "string"
+            },
+            "animals": {
+                "items": {
+                    "type": "string"
+                },
+                "type": "array"
+            }
+        },
+        "required": ["activity", "animals"]
+    }
+    return generate_pb2.NextTokenChooserParameters(
+        temperature=1.0,
+        repetition_penalty=1.3,
+        top_k=0,
+        top_p=1.0,
+        typical_p=1.0,
+        do_sample=False,
+        grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON,
+        grammar=json.dumps(grammar_schema).encode('utf-8'),
+    )
+
+
+@pytest.fixture(scope="session")
+def default_grammar_response():
+    return [
+        29912, 376, 29874, 312, 2068, 1115, 29871, 13, 29908, 29890,
+        638, 292, 613, 259, 376, 273, 3039, 29879, 1115, 518, 1678,
+        376, 26169, 3284, 4117, 3284, 336, 617, 6150, 3108, 500, 2
+    ]
+
+
+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return get_model("meta-llama/Llama-2-7b-hf", None, None, None, None)
+
+
+@pytest.fixture(scope="session")
+def default_tokenizer(default_causal_lm):
+    default_causal_lm.tokenizer.pad_token_id = PAD_TOKEN
+    return default_causal_lm.tokenizer
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        prefill_logprobs=True,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
+        parameters=default_pb_parameters,
+        stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10),
+    )
+
+
+@pytest.fixture
+def default_pb_grammar_request(default_pb_grammar_parameters):
+    return generate_pb2.Request(
+        id=1,
+        inputs=f"Please use the following JSON schema to generate the output: I saw a puppy a cat and a raccoon during my bike ride in the park",
+        prefill_logprobs=True,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
+        parameters=default_pb_grammar_parameters,
+        stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=50),
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_pb_grammar_batch(default_pb_grammar_request):
+    return generate_pb2.Batch(id=1, requests=[default_pb_grammar_request], size=1)
+
+
+@pytest.fixture
+def default_causal_lm_batch(default_pb_batch, default_tokenizer):
+    return CausalLMBatch.from_pb(
+        default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu")
+    )
+
+
+@pytest.fixture
+def default_causal_lm_grammar_batch(default_pb_grammar_batch, default_tokenizer):
+    return CausalLMBatch.from_pb(
+        default_pb_grammar_batch, default_tokenizer, torch.float32, torch.device("hpu")
+    )
+
+
+@pytest.fixture
+def default_two_causal_lm_grammar_batches(default_pb_grammar_request, default_tokenizer):
+    req_0 = default_pb_grammar_request
+    req_0.id = 0
+    req_1 = copy(default_pb_grammar_request)
+    req_1.id = 1
+
+    batch_0 = generate_pb2.Batch(id=0, requests=[req_0], size=1)
+    batch_1 = generate_pb2.Batch(id=1, requests=[req_1], size=1)
+    return [
+        CausalLMBatch.from_pb(
+            b, default_tokenizer, torch.float32, torch.device("hpu")
+        ) for b in [batch_0, batch_1]
+    ]
+
+
+def test_single_grammar_batch(
+    default_causal_lm, default_causal_lm_grammar_batch, default_grammar_response
+):
+    counter = 0
+    batch = default_causal_lm_grammar_batch
+
+    # prefill request
+    generations, next_batch, _ = default_causal_lm.generate_token([batch])
+
+    # generate until done
+    while next_batch is not None:
+        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+        assert len(generations) == 1
+        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+        counter += 1
+    print(generations[0].generated_text.text)
+
+
+def test_multi_grammar_batches(
+    default_causal_lm, default_two_causal_lm_grammar_batches, default_grammar_response
+):
+    counter_0, counter_1 = 0, 0
+    batch_0, batch_1 = default_two_causal_lm_grammar_batches
+
+    # prefill first request
+    generations, next_batch, _ = default_causal_lm.generate_token([batch_0])
+    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+    assert len(generations) == 1
+    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
+    counter_0 += 1
+
+    # prefill second request
+    generations, next_batch_1, _ = default_causal_lm.generate_token([batch_1])
+
+    # concatenate and generate
+    generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1])
+    assert len(generations) == 2
+    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
+    assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1]
+    counter_0 += 1
+    counter_1 += 1
+
+    # generate until first request is done
+    while generations[0].generated_text is None:
+        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+        assert len(generations) == 2
+        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
+        assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1]
+        counter_0 += 1
+        counter_1 += 1
+
+    # filter finished request
+    response = generations[0].generated_text.text
+    next_batch = next_batch.filter([next_batch.requests[1].data.id])
+
+    # generate last tokens for second request
+    while next_batch is not None:
+        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+        assert len(generations) == 1
+        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_1]
+        counter_1 += 1
+
+    assert response == generations[0].generated_text.text
+
+
+def test_grammar_and_causal_batch(
+    default_causal_lm, default_causal_lm_grammar_batch, default_causal_lm_batch, default_grammar_response
+):
+    counter = 0
+    generations, next_batch, _ = default_causal_lm.generate_token([default_causal_lm_grammar_batch])
+    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+    assert len(generations) == 1
+    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+    counter += 1
+
+    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+    assert len(generations) == 1
+    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+    counter += 1
+
+    # prefill second request
+    generations, next_batch_1, _ = default_causal_lm.generate_token([default_causal_lm_batch])
+
+    # concatenate and generate
+    generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1])
+    assert len(generations) == 2
+    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+    counter += 1
+
+    # generate until second request is done
+    for _ in range(
+        next_batch.requests[1].stopping_criteria.max_new_tokens - 1
+    ):
+        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+        assert len(generations) == 2
+        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+        counter += 1
+
+    # filter finished request
+    assert len(generations) == 2
+    assert (
+        generations[1].request_id == next_batch.requests[1].data.id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens == next_batch.requests[1].stopping_criteria.max_new_tokens
+    )
+    assert generations[1].generated_text.text == "ing the effect of a new method for the detection"
+    next_batch = next_batch.filter([next_batch.requests[0].data.id])
+
+    # generate until done
+    while next_batch is not None:
+        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
+        assert len(generations) == 1
+        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
+        counter += 1
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 65ba35b9..c99bc79a 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -358,7 +358,7 @@ class CausalLMBatch(Batch):
 
         moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
         dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size != new_bs)
+        reshape = (batches[dst_batch_idx].batch_size < new_bs)
 
         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
@@ -397,16 +397,23 @@ class CausalLMBatch(Batch):
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         parameters = [r.data.parameters for r in flat_requests]
-        if len(flat_requests) < new_bs:
-            for i in range(new_bs-len(flat_requests)) :
-                # append the dummy parameters for dummy request
-                parameters.append(parameters[0])
+        # append the dummy parameters for dummy requests
+        batch_size = batches[dst_batch_idx].batch_size
+        parameters.extend(
+            [generate_pb2.NextTokenChooserParameters()] * (batch_size - len(flat_requests))
+        )
+
+        fsm_grammar_states = [0] * batch_size
+        for batch in batches:
+            for i, req in enumerate(batch.requests):
+                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device,
             batches[dst_batch_idx].next_token_chooser.tokenizer,
+            fsm_grammar_states,
             quantization_enabled=hq_env.is_quantization_enabled,
         )
 
@@ -454,12 +461,13 @@ class CausalLMBatch(Batch):
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
         new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
-        dummy_inputs = ["?"] * (new_bs - len(requests))
+        missing_inputs = new_bs - len(requests)
+        dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
-        if len(pb.requests) < new_bs:
-            for i in range(new_bs-len(pb.requests)) :
-                #append the dummy parameters for dummy request
-                parameters.append(parameters[0])
+        # append the dummy parameters for dummy requests
+        parameters.extend(
+            [generate_pb2.NextTokenChooserParameters()] * missing_inputs
+        )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=parameters,
@@ -889,6 +897,7 @@ class CausalLM(Model):
                 'top_n_tokens': batch.top_n_tokens[req_idx],
                 'top_token_ids': batch_top_token_ids[req_idx],
                 'top_token_logprobs': batch_top_token_logprobs[req_idx],
+                'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
             })
 
         htorch.core.mark_step()
@@ -986,6 +995,7 @@ class CausalLM(Model):
             top_n_tokens = req_data['top_n_tokens']
             top_token_ids = req_data['top_token_ids']
             top_token_logprobs = req_data['top_token_logprobs']
+            grammar_state = req_data['grammar_state']
 
             # Append next token to all tokens
             all_input_ids[input_length] = next_token_id
@@ -1087,6 +1097,12 @@ class CausalLM(Model):
 
                 generations.append(generation)
 
+            batch.next_token_chooser = (
+                batch.next_token_chooser.advance_grammar_single_with_past_state(
+                    req.idx, next_token_id, grammar_state
+                )
+            )
+
             req.all_input_ids = all_input_ids
             req.input_length = new_input_length
             req.prefix_offset = prefix_offset
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 312583e3..b9da9c14 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -542,7 +542,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         mask = torch.full_like(logits, -math.inf)
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
-            if fsm_grammar_states[i] == -1 or fsm is None:
+            if fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
             mask[i, allowed_tokens] = 0
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index c879e312..ef445964 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -418,6 +418,18 @@ class HeterogeneousNextTokenChooser:
             )
         return self
 
+    def advance_grammar_single_with_past_state(
+        self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
+    ):
+        if self.grammar_processor is not None:
+            next_id = next_id.item()
+            self.fsm_grammar_states[grammar_state_index] = (
+                self.grammar_processor.advance_at_index(
+                    next_id, past_state, grammar_state_index,
+                )
+            )
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)