fix: adjust client and conftest for grammar

This commit is contained in:
drbh 2024-02-13 02:51:28 +00:00
parent 5ba1baccb0
commit d39e45abc3
2 changed files with 12 additions and 4 deletions

View File

@@ -76,7 +76,7 @@ class Client:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "" grammar: str = "",
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text Given a prompt, generate the following text
@@ -139,7 +139,7 @@ class Client:
watermark=watermark, watermark=watermark,
decoder_input_details=decoder_input_details, decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar grammar=grammar,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -171,6 +171,7 @@ class Client:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "",
) -> Iterator[StreamResponse]: ) -> Iterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens Given a prompt, generate the following stream of tokens
@@ -229,6 +230,7 @@ class Client:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -328,7 +330,7 @@ class AsyncClient:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "" grammar: str = "",
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text asynchronously Given a prompt, generate the following text asynchronously
@@ -373,6 +375,7 @@ class AsyncClient:
Returns: Returns:
Response: generated response Response: generated response
""" """
# Validate parameters # Validate parameters
parameters = Parameters( parameters = Parameters(
best_of=best_of, best_of=best_of,
@@ -421,6 +424,7 @@ class AsyncClient:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "",
) -> AsyncIterator[StreamResponse]: ) -> AsyncIterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens asynchronously Given a prompt, generate the following stream of tokens asynchronously
@@ -479,6 +483,7 @@ class AsyncClient:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@@ -290,12 +290,15 @@ def launcher(event_loop):
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True, use_flash_attention: bool = True,
grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"] args = ["--model-id", model_id, "--env"]
if grammar_support:
args.append("--grammar-support")
if num_shard is not None: if num_shard is not None:
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize is not None: if quantize is not None:
@@ -378,7 +381,7 @@ def generate_load():
max_new_tokens: int, max_new_tokens: int,
n: int, n: int,
seed: Optional[int] = None, seed: Optional[int] = None,
grammar: Optional[str] = None, grammar: Optional[str] = "",
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
) -> List[Response]: ) -> List[Response]:
futures = [ futures = [