This commit is contained in:
Nicolas Patry 2024-02-26 16:31:01 +01:00
parent fa40801fb6
commit e672f976fb
2 changed files with 7 additions and 1 deletions

View File

@@ -236,6 +236,7 @@ def launcher(event_loop):
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False, disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
@@ -268,6 +269,9 @@ def launcher(event_loop):
if dtype is not None: if dtype is not None:
args.append("--dtype") args.append("--dtype")
args.append(dtype) args.append(dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
if trust_remote_code: if trust_remote_code:
args.append("--trust-remote-code") args.append("--trust-remote-code")

View File

@@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_medusa_handle(launcher): def flash_medusa_handle(launcher):
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1") as handle: with launcher(
"FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
) as handle:
yield handle yield handle