This commit is contained in:
OlivierDehaene 2023-03-07 12:53:03 +01:00
parent 0a27d56634
commit 6d1ad06471
7 changed files with 262 additions and 104 deletions

clients/python/README.md Normal file
View File

View File

@ -0,0 +1 @@
from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient

View File

@ -1,35 +1,101 @@
import json
import os

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator

from text_generation_inference import SUPPORTED_MODELS
from text_generation_inference.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation_inference.errors import parse_error, NotSupportedError

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)


class AsyncClient:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                async for byte_payload in resp.content:
                    if byte_payload == b"\n":

@ -42,22 +108,35 @@ class AsyncClient:

                    try:
                        response = StreamResponse(**json_payload)
                    except ValidationError:
                        raise parse_error(resp.status, json_payload)
                    yield response


class APIInferenceAsyncClient(AsyncClient):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        if model_id not in SUPPORTED_MODELS:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)


if __name__ == "__main__":

    async def main():
        client = APIInferenceAsyncClient(
            "bigscience/bloomz", token="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"
        )
        async for token in client.generate_stream("test"):
            print(token)

        print(await client.generate("test"))

    import asyncio

    asyncio.run(main())
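
Not part of the diff, but a minimal usage sketch of the generic AsyncClient above: the endpoint URL is a hypothetical self-hosted server, and the sketch only exercises the generate and generate_stream methods shown in this file.

    import asyncio

    from text_generation_inference.async_client import AsyncClient


    async def demo():
        # Hypothetical self-hosted endpoint; replace with the URL your server actually exposes.
        client = AsyncClient("http://127.0.0.1:8080", timeout=1)

        # Unary call: returns a Response with generated_text and details.
        response = await client.generate("What is Deep Learning?", max_new_tokens=17)
        print(response.generated_text)

        # Streaming call: yields one StreamResponse per generated token.
        async for stream_response in client.generate_stream("What is Deep Learning?"):
            print(stream_response.token)


    asyncio.run(demo())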

View File

@ -0,0 +1,93 @@
from typing import Dict


# Text Generation Inference Errors
class ValidationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class GenerationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class OverloadedError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class IncompleteGenerationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


# API Inference Errors
class BadRequestError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class ShardNotReadyError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class TimeoutError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class NotFoundError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class RateLimitExceededError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class NotSupportedError(Exception):
    def __init__(self, model_id: str):
        message = (
            f"Model `{model_id}` is not available for inference with this client. \n"
            "Use `huggingface_hub.inference_api.InferenceApi` instead."
        )
        super(NotSupportedError, self).__init__(message)


# Unknown error
class UnknownError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
    # Try to parse a Text Generation Inference error
    message = payload["error"]
    if "error_type" in payload:
        error_type = payload["error_type"]
        if error_type == "generation":
            return GenerationError(message)
        if error_type == "incomplete_generation":
            return IncompleteGenerationError(message)
        if error_type == "overloaded":
            return OverloadedError(message)
        if error_type == "validation":
            return ValidationError(message)

    # Try to parse an API Inference error
    if status_code == 400:
        return BadRequestError(message)
    if status_code == 403 or status_code == 424:
        return ShardNotReadyError(message)
    if status_code == 504:
        return TimeoutError(message)
    if status_code == 404:
        return NotFoundError(message)
    if status_code == 429:
        return RateLimitExceededError(message)

    # Fallback to an unknown error
    return UnknownError(message)
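
A quick illustration of how parse_error routes payloads (the payloads below are made up for the example, not taken from the commit):

    from text_generation_inference.errors import (
        parse_error,
        ValidationError,
        RateLimitExceededError,
        UnknownError,
    )

    # A Text Generation Inference payload carries an "error_type" and is mapped by type.
    err = parse_error(422, {"error": "Input validation error", "error_type": "validation"})
    assert isinstance(err, ValidationError)

    # An API Inference payload has no "error_type", so the HTTP status code decides.
    err = parse_error(429, {"error": "Rate limit reached"})
    assert isinstance(err, RateLimitExceededError)

    # Anything unrecognized falls back to UnknownError.
    err = parse_error(500, {"error": "Internal server error"})
    assert isinstance(err, UnknownError)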

View File

@ -1,59 +1,8 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List

from text_generation_inference.errors import ValidationError


class Parameters(BaseModel):

@ -61,12 +10,13 @@ class Parameters(BaseModel):

    max_new_tokens: int = 20
    repetition_penalty: Optional[float] = None
    return_full_text: bool = False
    seed: Optional[int]
    stop: Optional[List[str]]
    temperature: Optional[float]
    top_k: Optional[int]
    top_p: Optional[float]
    watermark: bool = False
    details: bool = False

    @validator("seed")
    def valid_seed(cls, v):

@ -93,8 +43,16 @@ class Parameters(BaseModel):

        return v


class Request(BaseModel):
    inputs: str
    parameters: Parameters
    stream: bool = False


class PrefillToken(BaseModel):
    id: int
    text: str
    logprob: Optional[float]


class Token(BaseModel):

@ -104,5 +62,32 @@ class Token(BaseModel):

    special: bool


class FinishReason(Enum):
    Length = "length"
    EndOfSequenceToken = "eos_token"
    StopSequence = "stop_sequence"


class Details(BaseModel):
    finish_reason: FinishReason
    generated_tokens: int
    seed: Optional[int]
    prefill: List[PrefillToken]
    tokens: List[Token]


class StreamDetails(BaseModel):
    finish_reason: FinishReason
    generated_tokens: int
    seed: Optional[int]


class Response(BaseModel):
    generated_text: str
    details: Details


class StreamResponse(BaseModel):
    token: Token
    generated_text: Optional[str]
    details: Optional[StreamDetails]
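
As a hedged sketch of how these models serialize, mirroring what AsyncClient.generate builds before posting (the prompt and sampling values below are illustrative, not from the commit):

    from text_generation_inference.types import Parameters, Request

    parameters = Parameters(
        details=True,
        do_sample=True,
        max_new_tokens=20,
        seed=42,
        stop=["\n"],
        temperature=0.8,
        top_p=0.95,
    )
    request = Request(inputs="What is Deep Learning?", stream=False, parameters=parameters)

    # request.dict() is the JSON body that AsyncClient posts to the server.
    print(request.dict())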

View File

@ -152,8 +152,8 @@ pub(crate) struct Details {
    pub generated_tokens: u32,
    #[schema(example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
}

#[derive(Serialize, ToSchema)]

View File

@ -133,8 +133,8 @@ async fn generate(
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: response.prefill,
            tokens: response.tokens,
            seed: response.generated_text.seed,
        }),
        false => None,
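
With prefill and tokens now required on the Rust side, the details object always carries both arrays whenever details are requested, which is why the Python Details model above declares them as plain lists. As a hedged illustration (the payload below is made up, and it assumes Token exposes the same id/text/logprob fields as PrefillToken), such a payload validates directly against the client's Response model:

    from text_generation_inference.types import Response

    # Illustrative payload in the shape the router now returns when details=True:
    # prefill and tokens are plain arrays rather than nullable fields.
    payload = {
        "generated_text": " a subfield of machine learning",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 5,
            "seed": None,
            "prefill": [{"id": 1, "text": "What", "logprob": None}],
            "tokens": [{"id": 257, "text": " a", "logprob": -1.2, "special": False}],
        },
    }

    response = Response(**payload)
    assert len(response.details.tokens) == 1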