OlivierDehaene 2023-03-07 12:53:03 +01:00
parent 0a27d56634
commit 6d1ad06471
7 changed files with 262 additions and 104 deletions

clients/python/README.md Normal file
View File

View File

@@ -0,0 +1 @@
from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient
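The file name for this one-line hunk is not preserved above, but it reads like the package root re-exporting both clients. Assuming that layout (e.g. text_generation_inference/__init__.py, which is a guess), downstream code can then import the clients directly from the package:

    # Hypothetical import, assuming the re-export above sits in the package root.
    from text_generation_inference import AsyncClient, APIInferenceAsyncClient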

View File

@@ -1,35 +1,101 @@
import json
import os
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Optional
from typing import Dict, Optional, List, AsyncIterator
from text_generation_inference.types import StreamResponse, ErrorModel, Response
from text_generation_inference import SUPPORTED_MODELS
from text_generation_inference.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation_inference.errors import parse_error, NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
class AsyncClient:
def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
headers = {}
if token is not None:
headers = {"Authorization": f"Bearer {token}"}
self.model_id = model_id
def __init__(
self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
):
self.base_url = base_url
self.headers = headers
self.timeout = ClientTimeout(timeout * 60)
self.session = ClientSession(headers=headers, timeout=ClientTimeout(timeout * 60))
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async def generate(self):
async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
json={"inputs": "test", "stream": True}) as resp:
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
error = ErrorModel(**await resp.json())
raise error.to_exception()
return Response(**await resp.json())
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(self):
async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
json={"inputs": "test", "stream": True}) as resp:
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> AsyncIterator[StreamResponse]:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
error = ErrorModel(**await resp.json())
raise error.to_exception()
raise parse_error(resp.status, await resp.json())
async for byte_payload in resp.content:
if byte_payload == b"\n":
@@ -42,20 +108,33 @@ class AsyncClient:
try:
response = StreamResponse(**json_payload)
except ValidationError:
error = ErrorModel(**json_payload)
raise error.to_exception()
yield response.token
def __del__(self):
self.session.close()
raise parse_error(resp.status, json_payload)
yield response
class APIInferenceAsyncClient(AsyncClient):
def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
# Text Generation Inference client only supports a subset of the available hub models
if model_id not in SUPPORTED_MODELS:
raise NotSupportedError(model_id)
headers = {}
if token is not None:
headers = {"Authorization": f"Bearer {token}"}
base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
if __name__ == "__main__":
async def main():
client = AsyncClient("bigscience/bloomz")
async for token in client.generate_stream():
client = APIInferenceAsyncClient(
"bigscience/bloomz", token="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"
)
async for token in client.generate_stream("test"):
print(token)
print(await client.generate())
print(await client.generate("test"))
import asyncio
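Taken together, the hunks above turn the hard-coded api-inference call into a generic AsyncClient that POSTs a typed Request to any base_url, plus an APIInferenceAsyncClient that targets the hosted Inference API for supported models. Below is a minimal usage sketch, not taken from the commit itself; the model id, prompt and token are placeholders:

    import asyncio

    from text_generation_inference import APIInferenceAsyncClient


    async def main():
        # Models outside SUPPORTED_MODELS raise NotSupportedError (see the errors module below).
        client = APIInferenceAsyncClient("bigscience/bloomz", token="<your-hf-token>")

        # Non-streaming call: a single POST that parses into a Response model.
        response = await client.generate("What is Deep Learning?", max_new_tokens=20)
        print(response.generated_text)

        # Streaming call: yields StreamResponse objects as tokens arrive;
        # the token field mirrors the Token model defined in the types hunk below.
        async for stream_response in client.generate_stream("What is Deep Learning?"):
            print(stream_response.token.text, end="")


    asyncio.run(main())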

View File

@@ -0,0 +1,93 @@
from typing import Dict
# Text Generation Inference Errors
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
# API Inference Errors
class BadRequestError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardNotReadyError(Exception):
def __init__(self, message: str):
super().__init__(message)
class TimeoutError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotFoundError(Exception):
def __init__(self, message: str):
super().__init__(message)
class RateLimitExceededError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotSupportedError(Exception):
def __init__(self, model_id: str):
message = (
f"Model `{model_id}` is not available for inference with this client. \n"
"Use `huggingface_hub.inference_api.InferenceApi` instead."
)
super(NotSupportedError, self).__init__(message)
# Unknown error
class UnknownError(Exception):
def __init__(self, message: str):
super().__init__(message)
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
# Try to parse a Text Generation Inference error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
# Try to parse a APIInference error
if status_code == 400:
return BadRequestError(message)
if status_code == 403 or status_code == 424:
return ShardNotReadyError(message)
if status_code == 504:
return TimeoutError(message)
if status_code == 404:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
# Fallback to an unknown error
return UnknownError(message)
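The new errors module centralizes error handling: when the payload carries an error_type, parse_error maps it to one of the text-generation-inference exceptions; otherwise the HTTP status code picks an API Inference exception, with UnknownError as the fallback. A small illustration of that mapping, using made-up payloads:

    from text_generation_inference.errors import (
        OverloadedError,
        RateLimitExceededError,
        parse_error,
    )

    # error_type takes precedence over the status code.
    err = parse_error(502, {"error": "Model is overloaded", "error_type": "overloaded"})
    assert isinstance(err, OverloadedError)

    # Without error_type, the status code decides (429 -> rate limiting).
    err = parse_error(429, {"error": "Rate limit reached"})
    assert isinstance(err, RateLimitExceededError)

    # Anything unrecognized falls through to UnknownError.
    err = parse_error(500, {"error": "something went wrong"})
    print(type(err).__name__)  # UnknownError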

View File

@@ -1,59 +1,8 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List, Type
from typing import Optional, List
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class InferenceAPIError(Exception):
def __init__(self, message: str):
super(InferenceAPIError, self).__init__(message)
class ErrorType(str, Enum):
generation = "generation"
incomplete_generation = "incomplete_generation"
overloaded = "overloaded"
validation = "validation"
def to_exception_type(self) -> Type[Exception]:
if self == ErrorType.generation:
return GenerationError
if self == ErrorType.incomplete_generation:
return IncompleteGenerationError
if self == ErrorType.overloaded:
return OverloadedError
if self == ErrorType.validation:
return ValidationError
raise ValueError("Unknown error")
class ErrorModel(BaseModel):
error_type: Optional[ErrorType]
error: str
def to_exception(self) -> Exception:
if self.error_type is not None:
return self.error_type.to_exception_type()(self.error)
return InferenceAPIError(self.error)
from text_generation_inference.errors import ValidationError
class Parameters(BaseModel):
@@ -61,12 +10,13 @@ class Parameters(BaseModel):
max_new_tokens: int = 20
repetition_penalty: Optional[float] = None
return_full_text: bool = False
stop: List[str]
seed: Optional[int]
stop: Optional[List[str]]
temperature: Optional[float]
top_k: Optional[int]
top_p: Optional[float]
watermark: bool = False
details: bool = False
@validator("seed")
def valid_seed(cls, v):
@@ -93,8 +43,16 @@ class Parameters(BaseModel):
return v
class Response(BaseModel):
generated_text: str
class Request(BaseModel):
inputs: str
parameters: Parameters
stream: bool = False
class PrefillToken(BaseModel):
id: int
text: str
logprob: Optional[float]
class Token(BaseModel):
@@ -104,5 +62,32 @@ class Token(BaseModel):
special: bool
class FinishReason(Enum):
Length = "length"
EndOfSequenceToken = "eos_token"
StopSequence = "stop_sequence"
class Details(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
prefill: List[PrefillToken]
tokens: List[Token]
class StreamDetails(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
class Response(BaseModel):
generated_text: str
details: Details
class StreamResponse(BaseModel):
token: Token
generated_text: Optional[str]
details: Optional[StreamDetails]
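With these changes the whole request/response schema lives in pydantic models: the client builds Parameters, wraps them in a Request, serializes it with .dict() for the POST body, and parses what comes back into Response or StreamResponse with their nested Details. A quick sketch of the request side, with placeholder values:

    from text_generation_inference.types import Parameters, Request

    parameters = Parameters(
        max_new_tokens=20,
        temperature=0.7,   # sampling knobs are checked by the validators in this module
        stop=["\n"],
        seed=42,
        details=True,      # ask the server to return the Details block
    )
    request = Request(inputs="What is Deep Learning?", stream=False, parameters=parameters)

    # This is the JSON body the async client POSTs (see session.post(..., json=request.dict())).
    print(request.dict())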

View File

@@ -152,8 +152,8 @@ pub(crate) struct Details {
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
pub prefill: Option<Vec<PrefillToken>>,
pub tokens: Option<Vec<Token>>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
}
#[derive(Serialize, ToSchema)]

View File

@@ -133,8 +133,8 @@ async fn generate(
true => Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill),
tokens: Some(response.tokens),
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
}),
false => None,