From d39e45abc351d10f62057dc280cf3f512039df87 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 13 Feb 2024 02:51:28 +0000
Subject: [PATCH] fix: adjust client and conftest for grammar

---
 clients/python/text_generation/client.py | 11 ++++++++---
 integration-tests/conftest.py            |  5 ++++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 06c29ce6..d11faa2a 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -76,7 +76,7 @@ class Client:
         watermark: bool = False,
         decoder_input_details: bool = False,
         top_n_tokens: Optional[int] = None,
-        grammar: str = ""
+        grammar: str = "",
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -139,7 +139,7 @@ class Client:
             watermark=watermark,
             decoder_input_details=decoder_input_details,
             top_n_tokens=top_n_tokens,
-            grammar=grammar
+            grammar=grammar,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -171,6 +171,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         top_n_tokens: Optional[int] = None,
+        grammar: str = "",
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -229,6 +230,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             top_n_tokens=top_n_tokens,
+            grammar=grammar,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
@@ -328,7 +330,7 @@ class AsyncClient:
         watermark: bool = False,
         decoder_input_details: bool = False,
         top_n_tokens: Optional[int] = None,
-        grammar: str = ""
+        grammar: str = "",
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -373,6 +375,7 @@ class AsyncClient:
         Returns:
             Response: generated response
         """
+
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
@@ -421,6 +424,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         top_n_tokens: Optional[int] = None,
+        grammar: str = "",
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -479,6 +483,7 @@ class AsyncClient:
             typical_p=typical_p,
             watermark=watermark,
             top_n_tokens=top_n_tokens,
+            grammar=grammar,
        )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 74643cf7..c97b039b 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -290,12 +290,15 @@ def launcher(event_loop):
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
+        grammar_support: bool = False,
         dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
 
         args = ["--model-id", model_id, "--env"]
 
+        if grammar_support:
+            args.append("--grammar-support")
         if num_shard is not None:
             args.extend(["--num-shard", str(num_shard)])
         if quantize is not None:
@@ -378,7 +381,7 @@ def generate_load():
         max_new_tokens: int,
         n: int,
         seed: Optional[int] = None,
-        grammar: Optional[str] = None,
+        grammar: Optional[str] = "",
         stop_sequences: Optional[List[str]] = None,
     ) -> List[Response]:
         futures = [
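
Note (illustrative, not part of the patch): a minimal sketch of how the new `grammar`
keyword added above could be exercised once this change is applied, assuming a
text-generation-inference server is already running locally. The server URL and the
grammar value are placeholders, and the accepted grammar syntax is defined by the
server, not by this client change.

    from text_generation import Client

    # Connect to a locally running text-generation-inference server (placeholder URL).
    client = Client("http://127.0.0.1:8080")

    # `grammar` defaults to "" in this revision and is forwarded into Parameters unchanged.
    response = client.generate(
        "What is the capital of France?",
        max_new_tokens=10,
        grammar="",  # placeholder; substitute a grammar string the server supports
    )
    print(response.generated_text)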