From 0a27d56634d0a312b1d9b7abe61685b80bbe5084 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 7 Mar 2023 10:14:49 +0100 Subject: [PATCH] wip --- .../text_generation_inference/async_client.py | 64 ++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/clients/python/text_generation_inference/async_client.py b/clients/python/text_generation_inference/async_client.py index 4867c47a..225d55ad 100644 --- a/clients/python/text_generation_inference/async_client.py +++ b/clients/python/text_generation_inference/async_client.py @@ -1,3 +1,63 @@ -# class AsyncClient: +import json -from asyncio import \ No newline at end of file +from aiohttp import ClientSession, ClientTimeout +from pydantic import ValidationError +from typing import Optional + +from text_generation_inference.types import StreamResponse, ErrorModel, Response + + +class AsyncClient: + def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10): + headers = {} + if token is not None: + headers = {"Authorization": f"Bearer {token}"} + self.model_id = model_id + + self.session = ClientSession(headers=headers, timeout=ClientTimeout(timeout * 60)) + + async def generate(self): + async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}", + json={"inputs": "test", "stream": True}) as resp: + if resp.status != 200: + error = ErrorModel(**await resp.json()) + raise error.to_exception() + return Response(**await resp.json()) + + async def generate_stream(self): + async with self.session.post(f"https://api-inference.huggingface.co/models/{self.model_id}", + json={"inputs": "test", "stream": True}) as resp: + if resp.status != 200: + error = ErrorModel(**await resp.json()) + raise error.to_exception() + + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + try: + response = StreamResponse(**json_payload) + except ValidationError: + error = ErrorModel(**json_payload) + raise error.to_exception() + yield response.token + + def __del__(self): + self.session.close() + + +async def main(): + client = AsyncClient("bigscience/bloomz") + async for token in client.generate_stream(): + print(token) + + print(await client.generate()) + + +import asyncio + +asyncio.run(main())