text-generation-inference/clients/python/text_generation/client.py
Nicolas Patry 211b54ac41
Rebased #617 (#868)
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same people; sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 11:43:47 +02:00

508 lines
19 KiB
Python

import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
    """Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import Client

    >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = timeout

    @staticmethod
    def _extract_event_data(line: bytes) -> Optional[str]:
        """Return the text following the `data:` prefix of a Server-Sent-Events line.

        Args:
            line (`bytes`):
                One raw line read from the response stream

        Returns:
            Optional[str]: the decoded text after `data:` with surrounding
            whitespace removed, or `None` if the line does not start with `data:`
            (blank lines and any other line are ignored).
        """
        prefix = b"data:"
        if not line.startswith(prefix):
            return None
        # Slice the prefix off instead of using `str.lstrip("data:")`: `lstrip` removes
        # any leading run of the characters d/a/t/: and would eat into the data itself.
        return line[len(prefix):].decode("utf-8").strip()

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response

        Raises:
            The exception built by `parse_error` when the server answers with a
            status code other than 200.
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            decoder_input_details=decoder_input_details,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
        )
        payload = resp.json()
        if resp.status_code != 200:
            raise parse_error(resp.status_code, payload)
        # The payload is indexed at 0, i.e. only its first element is returned
        return Response(**payload[0])

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Iterator[StreamResponse]: stream of generated tokens

        Raises:
            The exception built by `parse_error` when the server answers with a
            status code other than 200, or when an event payload cannot be
            validated as a `StreamResponse`.
        """
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            decoder_input_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )

        # With `stream=True` the response must be closed explicitly; do it even when
        # the consumer abandons the generator early or an error is raised.
        try:
            if resp.status_code != 200:
                raise parse_error(resp.status_code, resp.json())

            # Parse ServerSentEvents
            for byte_payload in resp.iter_lines():
                event_data = self._extract_event_data(byte_payload)
                # Skip blank lines and anything that is not an event data line
                if event_data is None:
                    continue

                # Decode payload
                json_payload = json.loads(event_data)
                # Parse payload
                try:
                    response = StreamResponse(**json_payload)
                except ValidationError:
                    # If we failed to parse the payload, then it is an error payload
                    raise parse_error(resp.status_code, json_payload)
                yield response
        finally:
            resp.close()
class AsyncClient:
    """Asynchronous Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import AsyncClient

    >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout value. It is multiplied by 60 before being handed to
                `ClientTimeout` (see the note in the body).
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        # NOTE(review): the value is scaled by 60 here, whereas the synchronous client
        # passes `timeout` through unchanged. If `ClientTimeout` expects seconds, this
        # argument is effectively in minutes. Kept as-is because callers may rely on
        # the longer timeout; confirm whether it is intended.
        self.timeout = ClientTimeout(timeout * 60)

    @staticmethod
    def _extract_event_data(line: bytes) -> Optional[str]:
        """Return the text following the `data:` prefix of a Server-Sent-Events line.

        Args:
            line (`bytes`):
                One raw line read from the response stream

        Returns:
            Optional[str]: the decoded text after `data:` with surrounding
            whitespace (including a trailing newline) removed, or `None` if the
            line does not start with `data:` (blank lines and any other line are
            ignored).
        """
        prefix = b"data:"
        if not line.startswith(prefix):
            return None
        # Slice the prefix off instead of using `str.lstrip("data:")`: `lstrip` removes
        # any leading run of the characters d/a/t/: and would eat into the data itself.
        return line[len(prefix):].decode("utf-8").strip()

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response

        Raises:
            The exception built by `parse_error` when the server answers with a
            status other than 200.
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                # The payload is indexed at 0, i.e. only its first element is returned
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens

        Raises:
            The exception built by `parse_error` when the server answers with a
            status other than 200, or when an event payload cannot be validated
            as a `StreamResponse`.
        """
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            decoder_input_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse ServerSentEvents
                async for byte_payload in resp.content:
                    event_data = self._extract_event_data(byte_payload)
                    # Skip blank lines and anything that is not an event data line
                    if event_data is None:
                        continue

                    # Decode payload
                    json_payload = json.loads(event_data)
                    # Parse payload
                    try:
                        response = StreamResponse(**json_payload)
                    except ValidationError:
                        # If we failed to parse the payload, then it is an error payload
                        raise parse_error(resp.status, json_payload)
                    yield response