Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-01 15:02:09 +00:00)

Add grammar support (#140)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>

Parent commit: 16f9ff8965
This commit: 32acdd55b4
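For context (editor's sketch, not part of this commit): the feature exercised below is grammar-constrained, JSON-schema decoding. A minimal client-side request for this kind of output could look roughly like the following, assuming the upstream text-generation-inference "guidance" HTTP API; the /generate route, the grammar parameter shape, and the localhost URL are assumptions taken from upstream docs, not from this diff.

import requests  # hypothetical client-side sketch, not part of this commit

schema = {
    "properties": {
        "activity": {"type": "string"},
        "animals": {"items": {"type": "string"}, "type": "array"},
    },
    "required": ["activity", "animals"],
}

response = requests.post(
    "http://localhost:8080/generate",  # assumed local TGI endpoint
    json={
        "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
        "parameters": {
            "max_new_tokens": 50,
            # Grammar payload shape follows the upstream TGI guidance docs (assumption).
            "grammar": {"type": "json", "value": schema},
        },
    },
)
print(response.json()["generated_text"])  # output is constrained to match the schema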
@@ -17,4 +17,4 @@ def default_pb_parameters():

@pytest.fixture
def default_pb_stop_parameters():
-    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
245  server/tests/models/test_grammar.py  (new file)
@@ -0,0 +1,245 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import json
import pytest
import torch

from copy import copy

from text_generation_server.pb import generate_pb2
from text_generation_server.models import get_model
from text_generation_server.models.causal_lm import (
    CausalLMBatch,
    PAD_SEQUENCE_TO_MULTIPLE_OF,
)

PAD_TOKEN = 0


@pytest.fixture
def default_pb_grammar_parameters():
    grammar_schema = {
        "properties": {
            "activity": {
                "type": "string"
            },
            "animals": {
                "items": {
                    "type": "string"
                },
                "type": "array"
            }
        },
        "required": ["activity", "animals"]
    }
    return generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        repetition_penalty=1.3,
        top_k=0,
        top_p=1.0,
        typical_p=1.0,
        do_sample=False,
        grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON,
        grammar=json.dumps(grammar_schema).encode('utf-8'),
    )


@pytest.fixture(scope="session")
def default_grammar_response():
    return [
        29912, 376, 29874, 312, 2068, 1115, 29871, 13, 29908, 29890,
        638, 292, 613, 259, 376, 273, 3039, 29879, 1115, 518, 1678,
        376, 26169, 3284, 4117, 3284, 336, 617, 6150, 3108, 500, 2
    ]
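# Editor's note (added for clarity, not part of the committed file): these appear to be
# Llama-2 tokenizer ids for the expected greedy, grammar-constrained completion; they
# should decode to a JSON object matching the schema above, roughly
# {"activity": "biking", "animals": ["puppy", "cat", "raccoon"]}, ending with EOS (2).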

@pytest.fixture(scope="session")
def default_causal_lm():
    return get_model("meta-llama/Llama-2-7b-hf", None, None, None, None)


@pytest.fixture(scope="session")
def default_tokenizer(default_causal_lm):
    default_causal_lm.tokenizer.pad_token_id = PAD_TOKEN
    return default_causal_lm.tokenizer


@pytest.fixture
def default_pb_request(default_pb_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
        prefill_logprobs=True,
        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
        parameters=default_pb_parameters,
        stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10),
    )


@pytest.fixture
def default_pb_grammar_request(default_pb_grammar_parameters):
    return generate_pb2.Request(
        id=1,
        inputs="Please use the following JSON schema to generate the output: I saw a puppy a cat and a raccoon during my bike ride in the park",
        prefill_logprobs=True,
        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
        parameters=default_pb_grammar_parameters,
        stopping_parameters=generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=50),
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_pb_grammar_batch(default_pb_grammar_request):
    return generate_pb2.Batch(id=1, requests=[default_pb_grammar_request], size=1)


@pytest.fixture
def default_causal_lm_batch(default_pb_batch, default_tokenizer):
    return CausalLMBatch.from_pb(
        default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu")
    )


@pytest.fixture
def default_causal_lm_grammar_batch(default_pb_grammar_batch, default_tokenizer):
    return CausalLMBatch.from_pb(
        default_pb_grammar_batch, default_tokenizer, torch.float32, torch.device("hpu")
    )


@pytest.fixture
def default_two_causal_lm_grammar_batches(default_pb_grammar_request, default_tokenizer):
    req_0 = default_pb_grammar_request
    req_0.id = 0
    req_1 = copy(default_pb_grammar_request)
    req_1.id = 1

    batch_0 = generate_pb2.Batch(id=0, requests=[req_0], size=1)
    batch_1 = generate_pb2.Batch(id=1, requests=[req_1], size=1)
    return [
        CausalLMBatch.from_pb(
            b, default_tokenizer, torch.float32, torch.device("hpu")
        ) for b in [batch_0, batch_1]
    ]


def test_single_grammar_batch(
    default_causal_lm, default_causal_lm_grammar_batch, default_grammar_response
):
    counter = 0
    batch = default_causal_lm_grammar_batch

    # prefill request
    generations, next_batch, _ = default_causal_lm.generate_token([batch])

    # generate until done
    while next_batch is not None:
        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
        assert len(generations) == 1
        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
        counter += 1
    print(generations[0].generated_text.text)


def test_multi_grammar_batches(
    default_causal_lm, default_two_causal_lm_grammar_batches, default_grammar_response
):
    counter_0, counter_1 = 0, 0
    batch_0, batch_1 = default_two_causal_lm_grammar_batches

    # prefill first request
    generations, next_batch, _ = default_causal_lm.generate_token([batch_0])
    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
    assert len(generations) == 1
    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
    counter_0 += 1

    # prefill second request
    generations, next_batch_1, _ = default_causal_lm.generate_token([batch_1])

    # concatenate and generate
    generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1])
    assert len(generations) == 2
    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
    assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1]
    counter_0 += 1
    counter_1 += 1

    # generate until first request is done
    while generations[0].generated_text is None:
        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
        assert len(generations) == 2
        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_0]
        assert generations[1].tokens.token_ids[0] == default_grammar_response[counter_1]
        counter_0 += 1
        counter_1 += 1

    # filter finished request
    response = generations[0].generated_text.text
    next_batch = next_batch.filter([next_batch.requests[1].data.id])

    # generate last tokens for second request
    while next_batch is not None:
        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
        assert len(generations) == 1
        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter_1]
        counter_1 += 1

    assert response == generations[0].generated_text.text


def test_grammar_and_causal_batch(
    default_causal_lm, default_causal_lm_grammar_batch, default_causal_lm_batch, default_grammar_response
):
    counter = 0
    generations, next_batch, _ = default_causal_lm.generate_token([default_causal_lm_grammar_batch])
    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
    assert len(generations) == 1
    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
    counter += 1

    generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
    assert len(generations) == 1
    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
    counter += 1

    # prefill second request
    generations, next_batch_1, _ = default_causal_lm.generate_token([default_causal_lm_batch])

    # concatenate and generate
    generations, next_batch, _ = default_causal_lm.generate_token([next_batch, next_batch_1])
    assert len(generations) == 2
    assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
    counter += 1

    # generate until second request is done
    for _ in range(
        next_batch.requests[1].stopping_criteria.max_new_tokens - 1
    ):
        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
        assert len(generations) == 2
        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
        counter += 1

    # filter finished request
    assert len(generations) == 2
    assert (
        generations[1].request_id == next_batch.requests[1].data.id
    )
    assert (
        generations[1].generated_text.generated_tokens == next_batch.requests[1].stopping_criteria.max_new_tokens
    )
    assert generations[1].generated_text.text == "ing the effect of a new method for the detection"
    next_batch = next_batch.filter([next_batch.requests[0].data.id])

    # generate until done
    while next_batch is not None:
        generations, next_batch, _ = default_causal_lm.generate_token([next_batch])
        assert len(generations) == 1
        assert generations[0].tokens.token_ids[0] == default_grammar_response[counter]
        counter += 1
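Presumably the new tests run with pytest from the repo root, e.g. "pytest server/tests/models/test_grammar.py" (the exact invocation is an assumption); the fixtures load the real meta-llama/Llama-2-7b-hf checkpoint and build batches on torch.device("hpu"), so an available HPU device and access to the model weights are required.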
@@ -358,7 +358,7 @@ class CausalLMBatch(Batch):

        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size != new_bs)
+        reshape = (batches[dst_batch_idx].batch_size < new_bs)

        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
        # FIXME: max_seq_len for non optimized code
@@ -397,16 +397,23 @@ class CausalLMBatch(Batch):
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

        parameters = [r.data.parameters for r in flat_requests]
-        if len(flat_requests) < new_bs:
-            for i in range(new_bs - len(flat_requests)):
-                # append the dummy parameters for dummy request
-                parameters.append(parameters[0])
+        # append the dummy parameters for dummy requests
+        batch_size = batches[dst_batch_idx].batch_size
+        parameters.extend(
+            [generate_pb2.NextTokenChooserParameters()] * (batch_size - len(flat_requests))
+        )

        fsm_grammar_states = [0] * batch_size
        for batch in batches:
            for i, req in enumerate(batch.requests):
                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            parameters,
            batches[dst_batch_idx].next_token_chooser.dtype,
            batches[dst_batch_idx].next_token_chooser.device,
            batches[dst_batch_idx].next_token_chooser.tokenizer,
            fsm_grammar_states,
            quantization_enabled=hq_env.is_quantization_enabled,
        )
@@ -454,12 +461,13 @@ class CausalLMBatch(Batch):
        # this means that we cannot shift inputs to the left after a long input sequence
        # was filtered out
        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
-        dummy_inputs = ["?"] * (new_bs - len(requests))
+        missing_inputs = new_bs - len(requests)
+        dummy_inputs = ["?"] * missing_inputs
        parameters = [r.parameters for r in pb.requests]
-        if len(pb.requests) < new_bs:
-            for i in range(new_bs - len(pb.requests)):
-                # append the dummy parameters for dummy request
-                parameters.append(parameters[0])
+        # append the dummy parameters for dummy request
+        parameters.extend(
+            [generate_pb2.NextTokenChooserParameters()] * missing_inputs
+        )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            pb=parameters,
@@ -889,6 +897,7 @@ class CausalLM(Model):
                    'top_n_tokens': batch.top_n_tokens[req_idx],
                    'top_token_ids': batch_top_token_ids[req_idx],
                    'top_token_logprobs': batch_top_token_logprobs[req_idx],
+                    'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
                })

        htorch.core.mark_step()
@@ -986,6 +995,7 @@ class CausalLM(Model):
            top_n_tokens = req_data['top_n_tokens']
            top_token_ids = req_data['top_token_ids']
            top_token_logprobs = req_data['top_token_logprobs']
+            grammar_state = req_data['grammar_state']

            # Append next token to all tokens
            all_input_ids[input_length] = next_token_id
@@ -1087,6 +1097,12 @@ class CausalLM(Model):

            generations.append(generation)

+            batch.next_token_chooser = (
+                batch.next_token_chooser.advance_grammar_single_with_past_state(
+                    req.idx, next_token_id, grammar_state
+                )
+            )
+
            req.all_input_ids = all_input_ids
            req.input_length = new_input_length
            req.prefix_offset = prefix_offset
@@ -542,7 +542,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
        mask = torch.full_like(logits, -math.inf)
        for i in range(logits.shape[0]):
            fsm = self.fsms[i]
-            if fsm_grammar_states[i] == -1 or fsm is None:
+            if fsm is None:
                continue
            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
            mask[i, allowed_tokens] = 0
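To make the masking logic above concrete, here is a self-contained toy sketch (illustrative only, with made-up sizes and values; it is not the repo's class): any token id outside the set the grammar FSM allows in the current state gets -inf added to its logit, so greedy or sampled decoding can only pick grammar-legal tokens.

import math
import torch

# Toy vocabulary of 10 tokens, batch of 1; values are made up for illustration.
logits = torch.randn(1, 10)

# In the real processor this comes from fsm.allowed_token_ids(fsm_grammar_states[i]).
allowed_tokens = [3, 7]

mask = torch.full_like(logits, -math.inf)
mask[0, allowed_tokens] = 0

# Adding the mask leaves allowed logits untouched and pushes the rest to -inf,
# so the next token is guaranteed to come from the allowed set.
constrained = logits + mask
assert constrained.argmax(dim=-1).item() in allowed_tokens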
@@ -418,6 +418,18 @@ class HeterogeneousNextTokenChooser:
            )
        return self

+    def advance_grammar_single_with_past_state(
+        self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
+    ):
+        if self.grammar_processor is not None:
+            next_id = next_id.item()
+            self.fsm_grammar_states[grammar_state_index] = (
+                self.grammar_processor.advance_at_index(
+                    next_id, past_state, grammar_state_index,
+                )
+            )
+        return self
+
    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)
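A stripped-down sketch of the bookkeeping the new helper performs (toy stand-ins only, not the repo's HeterogeneousNextTokenChooser or grammar processor): the chooser keeps one FSM state per request slot, and advance_grammar_single_with_past_state updates a single slot from an explicitly supplied previous state (the grammar_state snapshotted earlier in the decode loop above) rather than from whatever the slot currently holds.

# Hypothetical toy classes for illustration; the transition rule is made up.
class ToyGrammarProcessor:
    def advance_at_index(self, next_id: int, past_state: int, index: int) -> int:
        # A real grammar processor would ask the FSM which state follows
        # `past_state` once token `next_id` has been emitted.
        return past_state + 1


class ToyNextTokenChooser:
    def __init__(self, batch_size: int):
        self.grammar_processor = ToyGrammarProcessor()
        self.fsm_grammar_states = [0] * batch_size

    def advance_grammar_single_with_past_state(self, index: int, next_id: int, past_state: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_states[index] = self.grammar_processor.advance_at_index(
                next_id, past_state, index
            )
        return self


chooser = ToyNextTokenChooser(batch_size=2)
chooser = chooser.advance_grammar_single_with_past_state(index=1, next_id=29912, past_state=0)
assert chooser.fsm_grammar_states == [0, 1]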