From e672f976fb16f161b87eec54b6aae6f48498ed2d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Feb 2024 16:31:01 +0100 Subject: [PATCH] Fix . --- integration-tests/conftest.py | 4 ++++ integration-tests/models/test_flash_medusa.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 80457bc2..6b7a894c 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -236,6 +236,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -268,6 +269,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index f90d1d9a..27db5665 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_medusa_handle(launcher): - with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1") as handle: + with launcher( + "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1" + ) as handle: yield handle