From 6d1ad06471c2da5870d963a3ac21a39b5919668a Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 7 Mar 2023 12:53:03 +0100
Subject: [PATCH] wip

---
 clients/python/README.md                          |   0
 .../text_generation_inference/__init__.py         |   1 +
 .../text_generation_inference/async_client.py     | 167 +++++++++++++-----
 .../text_generation_inference/errors.py           |  93 ++++++++++
 .../python/text_generation_inference/types.py     |  97 +++++-----
 router/src/lib.rs                                 |   4 +-
 router/src/server.rs                              |   4 +-
 7 files changed, 262 insertions(+), 104 deletions(-)
 create mode 100644 clients/python/README.md
 create mode 100644 clients/python/text_generation_inference/errors.py

diff --git a/clients/python/README.md b/clients/python/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/clients/python/text_generation_inference/__init__.py b/clients/python/text_generation_inference/__init__.py
index e69de29b..874227e2 100644
--- a/clients/python/text_generation_inference/__init__.py
+++ b/clients/python/text_generation_inference/__init__.py
@@ -0,0 +1 @@
+from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient
diff --git a/clients/python/text_generation_inference/async_client.py b/clients/python/text_generation_inference/async_client.py
index 225d55ad..db106455 100644
--- a/clients/python/text_generation_inference/async_client.py
+++ b/clients/python/text_generation_inference/async_client.py
@@ -1,63 +1,142 @@
 import json
+import os
 
 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Optional
+from typing import Dict, Optional, List, AsyncIterator
 
-from text_generation_inference.types import StreamResponse, ErrorModel, Response
+from text_generation_inference import SUPPORTED_MODELS
+from text_generation_inference.types import (
+    StreamResponse,
+    Response,
+    Request,
+    Parameters,
+)
+from text_generation_inference.errors import parse_error, NotSupportedError
+
+INFERENCE_ENDPOINT = os.environ.get(
+    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
+)
 
 
 class AsyncClient:
+    def __init__(
+        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
+    ):
+        self.base_url = base_url
+        self.headers = headers
+        self.timeout = ClientTimeout(timeout * 60)
+
+    async def generate(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Response:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=False, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Response(**payload[0])
+
+    async def generate_stream(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> AsyncIterator[StreamResponse]:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=True, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                if resp.status != 200:
+                    raise parse_error(resp.status, await resp.json())
+
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+
+                    payload = byte_payload.decode("utf-8")
+
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+                        try:
+                            response = StreamResponse(**json_payload)
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+                        yield response
+
+
+class APIInferenceAsyncClient(AsyncClient):
     def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
+        # Text Generation Inference client only supports a subset of the available hub models
+        if model_id not in SUPPORTED_MODELS:
+            raise NotSupportedError(model_id)
+
         headers = {}
 
         if token is not None:
             headers = {"Authorization": f"Bearer {token}"}
 
-        self.model_id = model_id
+        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
 
-        self.session = ClientSession(headers=headers, timeout=ClientTimeout(timeout * 60))
-
-    async def generate(self):
-        async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
-                                      json={"inputs": "test", "stream": True}) as resp:
-            if resp.status != 200:
-                error = ErrorModel(**await resp.json())
-                raise error.to_exception()
-            return Response(**await resp.json())
-
-    async def generate_stream(self):
-        async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
-                                      json={"inputs": "test", "stream": True}) as resp:
-            if resp.status != 200:
-                error = ErrorModel(**await resp.json())
-                raise error.to_exception()
-
-            async for byte_payload in resp.content:
-                if byte_payload == b"\n":
-                    continue
-
-                payload = byte_payload.decode("utf-8")
-
-                if payload.startswith("data:"):
-                    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-                    try:
-                        response = StreamResponse(**json_payload)
-                    except ValidationError:
-                        error = ErrorModel(**json_payload)
-                        raise error.to_exception()
-                    yield response.token
-
-    def __del__(self):
-        self.session.close()
+        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
 
 
-async def main():
-    client = AsyncClient("bigscience/bloomz")
-    async for token in client.generate_stream():
-        print(token)
+if __name__ == "__main__":
+    async def main():
+        client = APIInferenceAsyncClient(
+            "bigscience/bloomz", token="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"
+        )
+        async for token in client.generate_stream("test"):
+            print(token)
 
-    print(await client.generate())
+        print(await client.generate("test"))
 
-import asyncio
+    import asyncio
 
-asyncio.run(main())
+    asyncio.run(main())
diff --git a/clients/python/text_generation_inference/errors.py b/clients/python/text_generation_inference/errors.py
new file mode 100644
index 00000000..c4289ce6
--- /dev/null
+++ b/clients/python/text_generation_inference/errors.py
@@ -0,0 +1,93 @@
+from typing import Dict
+
+
+# Text Generation Inference Errors
+class ValidationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class GenerationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class OverloadedError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class IncompleteGenerationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+# API Inference Errors
+class BadRequestError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class ShardNotReadyError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class TimeoutError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class NotFoundError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class RateLimitExceededError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class NotSupportedError(Exception):
+    def __init__(self, model_id: str):
+        message = (
+            f"Model `{model_id}` is not available for inference with this client. \n"
+            "Use `huggingface_hub.inference_api.InferenceApi` instead."
+        )
+        super(NotSupportedError, self).__init__(message)
+
+
+# Unknown error
+class UnknownError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
+    # Try to parse a Text Generation Inference error
+    message = payload["error"]
+    if "error_type" in payload:
+        error_type = payload["error_type"]
+        if error_type == "generation":
+            return GenerationError(message)
+        if error_type == "incomplete_generation":
+            return IncompleteGenerationError(message)
+        if error_type == "overloaded":
+            return OverloadedError(message)
+        if error_type == "validation":
+            return ValidationError(message)
+
+    # Try to parse a APIInference error
+    if status_code == 400:
+        return BadRequestError(message)
+    if status_code == 403 or status_code == 424:
+        return ShardNotReadyError(message)
+    if status_code == 504:
+        return TimeoutError(message)
+    if status_code == 404:
+        return NotFoundError(message)
+    if status_code == 429:
+        return RateLimitExceededError(message)
+
+    # Fallback to an unknown error
+    return UnknownError(message)
diff --git a/clients/python/text_generation_inference/types.py b/clients/python/text_generation_inference/types.py
index a9820a39..4ce65f29 100644
--- a/clients/python/text_generation_inference/types.py
+++ b/clients/python/text_generation_inference/types.py
@@ -1,59 +1,8 @@
 from enum import Enum
 from pydantic import BaseModel, validator
-from typing import Optional, List, Type
+from typing import Optional, List
 
-
-class ValidationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class GenerationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class OverloadedError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class IncompleteGenerationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class InferenceAPIError(Exception):
-    def __init__(self, message: str):
-        super(InferenceAPIError, self).__init__(message)
-
-
-class ErrorType(str, Enum):
-    generation = "generation"
-    incomplete_generation = "incomplete_generation"
-    overloaded = "overloaded"
-    validation = "validation"
-
-    def to_exception_type(self) -> Type[Exception]:
-        if self == ErrorType.generation:
-            return GenerationError
-        if self == ErrorType.incomplete_generation:
-            return IncompleteGenerationError
-        if self == ErrorType.overloaded:
-            return OverloadedError
-        if self == ErrorType.validation:
-            return ValidationError
-        raise ValueError("Unknown error")
-
-
-class ErrorModel(BaseModel):
-    error_type: Optional[ErrorType]
-    error: str
-
-    def to_exception(self) -> Exception:
-        if self.error_type is not None:
-            return self.error_type.to_exception_type()(self.error)
-        return InferenceAPIError(self.error)
+from text_generation_inference.errors import ValidationError
 
 
 class Parameters(BaseModel):
@@ -61,12 +10,13 @@ class Parameters(BaseModel):
     max_new_tokens: int = 20
     repetition_penalty: Optional[float] = None
     return_full_text: bool = False
+    stop: List[str]
     seed: Optional[int]
-    stop: Optional[List[str]]
     temperature: Optional[float]
     top_k: Optional[int]
     top_p: Optional[float]
     watermark: bool = False
+    details: bool = False
 
     @validator("seed")
     def valid_seed(cls, v):
@@ -93,8 +43,16 @@ class Parameters(BaseModel):
         return v
 
 
-class Response(BaseModel):
-    generated_text: str
+class Request(BaseModel):
+    inputs: str
+    parameters: Parameters
+    stream: bool = False
+
+
+class PrefillToken(BaseModel):
+    id: int
+    text: str
+    logprob: Optional[float]
 
 
 class Token(BaseModel):
@@ -104,5 +62,32 @@ class Token(BaseModel):
     special: bool
 
 
+class FinishReason(Enum):
+    Length = "length"
+    EndOfSequenceToken = "eos_token"
+    StopSequence = "stop_sequence"
+
+
+class Details(BaseModel):
+    finish_reason: FinishReason
+    generated_tokens: int
+    seed: Optional[int]
+    prefill: List[PrefillToken]
+    tokens: List[Token]
+
+
+class StreamDetails(BaseModel):
+    finish_reason: FinishReason
+    generated_tokens: int
+    seed: Optional[int]
+
+
+class Response(BaseModel):
+    generated_text: str
+    details: Details
+
+
 class StreamResponse(BaseModel):
     token: Token
+    generated_text: Optional[str]
+    details: Optional[StreamDetails]
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b3f078a4..af1be6d9 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -152,8 +152,8 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
    pub seed: Option<u64>,
-    pub prefill: Option<Vec<PrefillToken>>,
-    pub tokens: Option<Vec<Token>>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 0ccb470c..b8b6b440 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -133,8 +133,8 @@ async fn generate(
         true => Some(Details {
             finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
-            prefill: Some(response.prefill),
-            tokens: Some(response.tokens),
+            prefill: response.prefill,
+            tokens: response.tokens,
             seed: response.generated_text.seed,
         }),
         false => None,
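
For reviewers who want to exercise the new client, a minimal usage sketch follows (not part of the patch). It assumes the `clients/python` package is installed locally, that `SUPPORTED_MODELS` is defined in `__init__.py` and contains `bigscience/bloomz` (this wip commit imports the constant but does not yet add it), and that the hypothetical `HF_API_TOKEN` environment variable holds a valid Hugging Face API token.

import asyncio
import os

from text_generation_inference import APIInferenceAsyncClient


async def main():
    # HF_API_TOKEN is a placeholder; the client sends it as a Bearer token if set.
    client = APIInferenceAsyncClient(
        "bigscience/bloomz", token=os.environ.get("HF_API_TOKEN")
    )

    # Streaming path: one StreamResponse per generated token.
    async for stream_response in client.generate_stream("What is Deep Learning?"):
        print(stream_response.token.text, end="", flush=True)
    print()

    # Non-streaming path: a single Response with generated_text and details.
    response = await client.generate("What is Deep Learning?", max_new_tokens=20)
    print(response.generated_text)


if __name__ == "__main__":
    asyncio.run(main())

Both calls go through the `Request`/`Parameters` models added to `types.py`, and any non-200 response is surfaced as one of the typed exceptions from `errors.py` via `parse_error`.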