mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
wip
This commit is contained in:
parent 0a27d56634
commit 6d1ad06471
clients/python/README.md  (new file, 0 lines)

@@ -0,0 +1 @@
+from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient
@@ -1,35 +1,101 @@
 import json
+import os
+
 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Optional
+from typing import Dict, Optional, List, AsyncIterator

-from text_generation_inference.types import StreamResponse, ErrorModel, Response
+from text_generation_inference import SUPPORTED_MODELS
+from text_generation_inference.types import (
+    StreamResponse,
+    Response,
+    Request,
+    Parameters,
+)
+from text_generation_inference.errors import parse_error, NotSupportedError
+
+INFERENCE_ENDPOINT = os.environ.get(
+    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
+)


 class AsyncClient:
-    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
-        headers = {}
-        if token is not None:
-            headers = {"Authorization": f"Bearer {token}"}
-        self.model_id = model_id
-        self.session = ClientSession(headers=headers, timeout=ClientTimeout(timeout * 60))
+    def __init__(
+        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
+    ):
+        self.base_url = base_url
+        self.headers = headers
+        self.timeout = ClientTimeout(timeout * 60)

-    async def generate(self):
-        async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
-                                     json={"inputs": "test", "stream": True}) as resp:
-            if resp.status != 200:
-                error = ErrorModel(**await resp.json())
-                raise error.to_exception()
-            return Response(**await resp.json())
+    async def generate(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Response:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=False, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Response(**payload[0])

-    async def generate_stream(self):
-        async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}",
-                                     json={"inputs": "test", "stream": True}) as resp:
-            if resp.status != 200:
-                error = ErrorModel(**await resp.json())
-                raise error.to_exception()
-
-            async for byte_payload in resp.content:
-                if byte_payload == b"\n":
+    async def generate_stream(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> AsyncIterator[StreamResponse]:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=True, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                if resp.status != 200:
+                    raise parse_error(resp.status, await resp.json())
+
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
@@ -42,22 +108,35 @@ class AsyncClient:
                     try:
                         response = StreamResponse(**json_payload)
                     except ValidationError:
-                        error = ErrorModel(**json_payload)
-                        raise error.to_exception()
-                yield response.token
-
-    def __del__(self):
-        self.session.close()
+                        raise parse_error(resp.status, json_payload)
+                    yield response


-async def main():
-    client = AsyncClient("bigscience/bloomz")
-    async for token in client.generate_stream():
-        print(token)
+class APIInferenceAsyncClient(AsyncClient):
+    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
+        # Text Generation Inference client only supports a subset of the available hub models
+        if model_id not in SUPPORTED_MODELS:
+            raise NotSupportedError(model_id)
+
+        headers = {}
+        if token is not None:
+            headers = {"Authorization": f"Bearer {token}"}
+        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
+
+        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)

-    print(await client.generate())

-import asyncio
-
-asyncio.run(main())
+if __name__ == "__main__":
+
+    async def main():
+        client = APIInferenceAsyncClient(
+            "bigscience/bloomz", token="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"
+        )
+        async for token in client.generate_stream("test"):
+            print(token)
+
+        print(await client.generate("test"))
+
+    import asyncio
+
+    asyncio.run(main())
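The refactor above splits the client in two: a generic AsyncClient that posts a typed Request to whatever base_url it is given, and an APIInferenceAsyncClient that checks SUPPORTED_MODELS, resolves the hub model_id against INFERENCE_ENDPOINT, and injects the Authorization header. A minimal sketch of how a caller would consume this version and handle the exceptions it now raises through parse_error; the model name and token below are placeholders, not values taken from the diff:

# Sketch: consuming the refactored client; model name and token are placeholders.
import asyncio

from text_generation_inference.async_client import APIInferenceAsyncClient
from text_generation_inference.errors import (
    NotSupportedError,
    OverloadedError,
    RateLimitExceededError,
    ValidationError,
)


async def run():
    try:
        client = APIInferenceAsyncClient("bigscience/bloomz", token="<hf_token>")
    except NotSupportedError:
        # Raised eagerly when model_id is not in SUPPORTED_MODELS
        return

    try:
        response = await client.generate("What is Deep Learning?", max_new_tokens=17)
        print(response.generated_text)

        # Streaming variant: yields one StreamResponse per generated token
        async for stream_response in client.generate_stream("What is Deep Learning?"):
            print(stream_response.token)
    except ValidationError as exc:
        # The server rejected the request parameters
        print(f"invalid request: {exc}")
    except (OverloadedError, RateLimitExceededError):
        # Transient conditions; callers would typically back off and retry
        pass


asyncio.run(run())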
clients/python/text_generation_inference/errors.py  (new file, 93 lines)

@@ -0,0 +1,93 @@
+from typing import Dict
+
+
+# Text Generation Inference Errors
+class ValidationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class GenerationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class OverloadedError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class IncompleteGenerationError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+# API Inference Errors
+class BadRequestError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class ShardNotReadyError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class TimeoutError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class NotFoundError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class RateLimitExceededError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class NotSupportedError(Exception):
+    def __init__(self, model_id: str):
+        message = (
+            f"Model `{model_id}` is not available for inference with this client. \n"
+            "Use `huggingface_hub.inference_api.InferenceApi` instead."
+        )
+        super(NotSupportedError, self).__init__(message)
+
+
+# Unknown error
+class UnknownError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
+    # Try to parse a Text Generation Inference error
+    message = payload["error"]
+    if "error_type" in payload:
+        error_type = payload["error_type"]
+        if error_type == "generation":
+            return GenerationError(message)
+        if error_type == "incomplete_generation":
+            return IncompleteGenerationError(message)
+        if error_type == "overloaded":
+            return OverloadedError(message)
+        if error_type == "validation":
+            return ValidationError(message)
+
+    # Try to parse a APIInference error
+    if status_code == 400:
+        return BadRequestError(message)
+    if status_code == 403 or status_code == 424:
+        return ShardNotReadyError(message)
+    if status_code == 504:
+        return TimeoutError(message)
+    if status_code == 404:
+        return NotFoundError(message)
+    if status_code == 429:
+        return RateLimitExceededError(message)
+
+    # Fallback to an unknown error
+    return UnknownError(message)
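parse_error prefers the error_type field that the text-generation-inference router includes in its error payloads; only when that field is absent does it fall back to the HTTP status code, and anything unmatched becomes UnknownError. A small sketch of that mapping; the error message strings are illustrative, not taken from the server:

# Sketch: how parse_error chooses an exception type; messages are illustrative.
from text_generation_inference.errors import (
    GenerationError,
    RateLimitExceededError,
    UnknownError,
    parse_error,
)

# A router payload carries an explicit error_type, so the status code is ignored
err = parse_error(500, {"error": "Request failed during generation", "error_type": "generation"})
assert isinstance(err, GenerationError)

# Without error_type, the HTTP status code decides (here: rate limiting on the hosted API)
err = parse_error(429, {"error": "Rate limit exceeded"})
assert isinstance(err, RateLimitExceededError)

# Anything else falls back to UnknownError
err = parse_error(500, {"error": "Internal Server Error"})
assert isinstance(err, UnknownError)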
@@ -1,59 +1,8 @@
 from enum import Enum
 from pydantic import BaseModel, validator
-from typing import Optional, List, Type
+from typing import Optional, List

-
-class ValidationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class GenerationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class OverloadedError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class IncompleteGenerationError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-
-
-class InferenceAPIError(Exception):
-    def __init__(self, message: str):
-        super(InferenceAPIError, self).__init__(message)
-
-
-class ErrorType(str, Enum):
-    generation = "generation"
-    incomplete_generation = "incomplete_generation"
-    overloaded = "overloaded"
-    validation = "validation"
-
-    def to_exception_type(self) -> Type[Exception]:
-        if self == ErrorType.generation:
-            return GenerationError
-        if self == ErrorType.incomplete_generation:
-            return IncompleteGenerationError
-        if self == ErrorType.overloaded:
-            return OverloadedError
-        if self == ErrorType.validation:
-            return ValidationError
-        raise ValueError("Unknown error")
-
-
-class ErrorModel(BaseModel):
-    error_type: Optional[ErrorType]
-    error: str
-
-    def to_exception(self) -> Exception:
-        if self.error_type is not None:
-            return self.error_type.to_exception_type()(self.error)
-        return InferenceAPIError(self.error)
+from text_generation_inference.errors import ValidationError


 class Parameters(BaseModel):
@@ -61,12 +10,13 @@ class Parameters(BaseModel):
     max_new_tokens: int = 20
     repetition_penalty: Optional[float] = None
     return_full_text: bool = False
+    stop: List[str]
     seed: Optional[int]
-    stop: Optional[List[str]]
     temperature: Optional[float]
     top_k: Optional[int]
     top_p: Optional[float]
     watermark: bool = False
+    details: bool = False

     @validator("seed")
     def valid_seed(cls, v):
@@ -93,8 +43,16 @@ class Parameters(BaseModel):
         return v


-class Response(BaseModel):
-    generated_text: str
+class Request(BaseModel):
+    inputs: str
+    parameters: Parameters
+    stream: bool = False
+
+
+class PrefillToken(BaseModel):
+    id: int
+    text: str
+    logprob: Optional[float]


 class Token(BaseModel):
@@ -104,5 +62,32 @@ class Token(BaseModel):
     special: bool


+class FinishReason(Enum):
+    Length = "length"
+    EndOfSequenceToken = "eos_token"
+    StopSequence = "stop_sequence"
+
+
+class Details(BaseModel):
+    finish_reason: FinishReason
+    generated_tokens: int
+    seed: Optional[int]
+    prefill: List[PrefillToken]
+    tokens: List[Token]
+
+
+class StreamDetails(BaseModel):
+    finish_reason: FinishReason
+    generated_tokens: int
+    seed: Optional[int]
+
+
+class Response(BaseModel):
+    generated_text: str
+    details: Details
+
+
 class StreamResponse(BaseModel):
     token: Token
+    generated_text: Optional[str]
+    details: Optional[StreamDetails]
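The reworked types mirror the router schema: Parameters gains details and a required stop list, Request bundles inputs, parameters, and the stream flag, and Response/StreamResponse now embed Details/StreamDetails built from PrefillToken and Token entries. A sketch of serializing a request and validating a response, assuming pydantic v1 as used in the diff; the payload values are illustrative, and the Token fields other than special are assumptions about the pre-existing model that this hunk only partially shows:

# Sketch: round-tripping the new pydantic models; values are illustrative.
from text_generation_inference.types import Parameters, Request, Response

request = Request(
    inputs="What is Deep Learning?",
    stream=False,
    parameters=Parameters(details=True, max_new_tokens=17, stop=[]),
)
print(request.dict())  # the JSON body AsyncClient.generate posts to the server

payload = {
    "generated_text": "Deep Learning is ...",
    "details": {
        "finish_reason": "length",
        "generated_tokens": 17,
        "seed": None,
        "prefill": [{"id": 0, "text": "What", "logprob": None}],
        # Token fields besides `special` (id, text, logprob) are assumed here
        "tokens": [{"id": 1, "text": " Deep", "logprob": -0.5, "special": False}],
    },
}
response = Response(**payload)
print(response.details.finish_reason, response.details.generated_tokens)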
@@ -152,8 +152,8 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<PrefillToken>>,
-    pub tokens: Option<Vec<Token>>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
 }

 #[derive(Serialize, ToSchema)]
@@ -133,8 +133,8 @@ async fn generate(
         true => Some(Details {
             finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
-            prefill: Some(response.prefill),
-            tokens: Some(response.tokens),
+            prefill: response.prefill,
+            tokens: response.tokens,
             seed: response.generated_text.seed,
         }),
         false => None,