Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix: adjust client and conftest for grammar

parent 5ba1baccb0
commit d39e45abc3
client.py:

@@ -76,7 +76,7 @@ class Client:
         watermark: bool = False,
         decoder_input_details: bool = False,
         top_n_tokens: Optional[int] = None,
-        grammar: str = ""
+        grammar: str = "",
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -139,7 +139,7 @@ class Client:
             watermark=watermark,
             decoder_input_details=decoder_input_details,
             top_n_tokens=top_n_tokens,
-            grammar=grammar
+            grammar=grammar,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -171,6 +171,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         top_n_tokens: Optional[int] = None,
+        grammar: str = "",
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -229,6 +230,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             top_n_tokens=top_n_tokens,
+            grammar=grammar,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -328,7 +330,7 @@ class AsyncClient:
         watermark: bool = False,
         decoder_input_details: bool = False,
         top_n_tokens: Optional[int] = None,
-        grammar: str = ""
+        grammar: str = "",
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -373,6 +375,7 @@ class AsyncClient:
         Returns:
             Response: generated response
         """
+
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
@@ -421,6 +424,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         top_n_tokens: Optional[int] = None,
+        grammar: str = "",
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -479,6 +483,7 @@ class AsyncClient:
             typical_p=typical_p,
             watermark=watermark,
             top_n_tokens=top_n_tokens,
+            grammar=grammar,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
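For reference, a minimal usage sketch of the client API after these changes. The endpoint URL, prompt, and grammar value below are placeholders and are not taken from the commit; the exact format of the grammar string (e.g. a serialized JSON-schema or regex grammar) is not defined by this diff.

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# Non-streaming: generate() accepts grammar, defaulting to "".
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=20,
    grammar="<serialized grammar>",  # placeholder value, format not specified here
)
print(response.generated_text)

# Streaming: generate_stream() now forwards the same grammar parameter.
for stream_response in client.generate_stream(
    "What is Deep Learning?",
    max_new_tokens=20,
    grammar="<serialized grammar>",  # placeholder value
):
    if not stream_response.token.special:
        print(stream_response.token.text, end="")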
conftest.py:

@@ -290,12 +290,15 @@ def launcher(event_loop):
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
+        grammar_support: bool = False,
         dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)

         args = ["--model-id", model_id, "--env"]

+        if grammar_support:
+            args.append("--grammar-support")
         if num_shard is not None:
             args.extend(["--num-shard", str(num_shard)])
         if quantize is not None:
@@ -378,7 +381,7 @@ def generate_load():
         max_new_tokens: int,
         n: int,
         seed: Optional[int] = None,
-        grammar: Optional[str] = None,
+        grammar: Optional[str] = "",
         stop_sequences: Optional[List[str]] = None,
     ) -> List[Response]:
         futures = [
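For context, the launcher fixture in conftest builds the text-generation-launcher command line used by the integration tests; the new grammar_support flag appends --grammar-support to that command. Below is a minimal sketch of how a test module might opt in, assuming the context-manager fixture, asyncio marker, and handle helpers used elsewhere in these tests; the model id and grammar value are placeholders, not taken from this commit.

import pytest


@pytest.fixture(scope="module")
def grammar_handle(launcher):
    # Hypothetical fixture: start the server with the new --grammar-support flag.
    with launcher("org/some-model", num_shard=1, grammar_support=True) as handle:
        yield handle


@pytest.mark.asyncio
async def test_generate_with_grammar(grammar_handle):
    await grammar_handle.health(300)
    # grammar now defaults to "" in generate_load and the clients;
    # pass a non-empty spec to constrain generation.
    response = await grammar_handle.client.generate(
        "name: david. email: ",
        max_new_tokens=10,
        grammar="<serialized grammar>",  # placeholder value
    )
    assert response.details.generated_tokens <= 10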