diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 6fd3365d..74643cf7 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -224,6 +224,7 @@ def launcher(event_loop):
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
+        grammar_support: bool = False,
         dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -247,6 +248,8 @@ def launcher(event_loop):
 
         env = os.environ
 
+        if grammar_support:
+            args.append("--grammar-support")
         if num_shard is not None:
             args.extend(["--num-shard", str(num_shard)])
         if quantize is not None:
diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py
index 7492718b..f4634fbd 100644
--- a/integration-tests/models/test_grammar_llama.py
+++ b/integration-tests/models/test_grammar_llama.py
@@ -4,7 +4,7 @@ import json
 
 
 @pytest.fixture(scope="module")
 def flash_llama_grammar_handle(launcher):
-    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2) as handle:
+    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True) as handle:
         yield handle
 