From 75f954df6ce4eb07cdb0460047ed618da50b677d Mon Sep 17 00:00:00 2001
From: Sabidao
Date: Sun, 21 Apr 2024 18:10:28 +0300
Subject: [PATCH] ensure aiohttp session exists

Create the shared aiohttp ClientSession lazily through
AsyncClient.ensure_session() and reuse it for requests instead of opening
a new ClientSession per call. The session is (re)created only when it is
missing or already closed.

_chat_stream_response is not touched by this patch and still opens its
own ClientSession.

---
 .vscode/settings.json                    |  5 ++
 clients/python/text_generation/client.py | 92 ++++++++++++------------
 2 files changed, 52 insertions(+), 45 deletions(-)
 create mode 100644 .vscode/settings.json

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..b242572e
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,5 @@
+{
+    "githubPullRequests.ignoredPullRequestBranches": [
+        "main"
+    ]
+}
\ No newline at end of file
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index d56b2ad2..9e2bba0d 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -1,4 +1,5 @@
 import json
+import aiohttp
 import requests
 
 from aiohttp import ClientSession, ClientTimeout
@@ -449,11 +450,17 @@ class AsyncClient:
         self.headers = headers
         self.cookies = cookies
         self.timeout = ClientTimeout(timeout)
+        self.session: ClientSession | None = None
+
+    async def ensure_session(self):
+        """Create the shared ClientSession if it is missing or has been closed."""
+        if self.session is None or self.session.closed:
+            self.session = ClientSession(
+                headers=self.headers, cookies=self.cookies, timeout=self.timeout
+            )
 
     async def __aenter__(self):
-        self.session = ClientSession(
-            headers=self.headers, cookies=self.cookies, timeout=self.timeout
-        )
+        await self.ensure_session()
         return self
 
     async def __aexit__(self, exc_type, exc, tb):
@@ -518,6 +525,8 @@ class AsyncClient:
                 The tool to use
 
         """
+        await self.ensure_session()
+
         request = ChatRequest(
             model="tgi",
             messages=messages,
@@ -542,16 +551,13 @@ class AsyncClient:
             return self._chat_stream_response(request)
 
     async def _chat_single_response(self, request):
-        async with ClientSession(
-            headers=self.headers, cookies=self.cookies, timeout=self.timeout
-        ) as session:
-            async with session.post(
-                f"{self.base_url}/v1/chat/completions", json=request.dict()
-            ) as resp:
-                payload = await resp.json()
-                if resp.status != 200:
-                    raise parse_error(resp.status, payload)
-                return ChatComplete(**payload)
+        async with self.session.post(
+            f"{self.base_url}/v1/chat/completions", json=request.dict()
+        ) as resp:
+            payload = await resp.json()
+            if resp.status != 200:
+                raise parse_error(resp.status, payload)
+            return ChatComplete(**payload)
 
     async def _chat_stream_response(self, request):
         async with ClientSession(
@@ -643,7 +649,7 @@ class AsyncClient:
         Returns:
             Response: generated response
         """
-
+        await self.ensure_session()
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
@@ -667,15 +673,12 @@ class AsyncClient:
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
-        async with ClientSession(
-            headers=self.headers, cookies=self.cookies, timeout=self.timeout
-        ) as session:
-            async with session.post(self.base_url, json=request.dict()) as resp:
-                payload = await resp.json()
+        async with self.session.post(self.base_url, json=request.dict()) as resp:
+            payload = await resp.json()
 
-                if resp.status != 200:
-                    raise parse_error(resp.status, payload)
-                return Response(**payload[0])
+            if resp.status != 200:
+                raise parse_error(resp.status, payload)
+            return Response(**payload[0])
 
     async def generate_stream(
         self,
@@ -743,6 +746,8 @@ class AsyncClient:
             AsyncIterator[StreamResponse]: stream of generated tokens
         """
         # Validate parameters
+        await self.ensure_session()
+
         parameters = Parameters(
             best_of=None,
             details=True,
@@ -765,29 +770,26 @@ class AsyncClient:
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
-        async with ClientSession(
-            headers=self.headers, cookies=self.cookies, timeout=self.timeout
-        ) as session:
-            async with session.post(self.base_url, json=request.dict()) as resp:
-                if resp.status != 200:
-                    raise parse_error(resp.status, await resp.json())
+        async with self.session.post(self.base_url, json=request.dict()) as resp:
+            if resp.status != 200:
+                raise parse_error(resp.status, await resp.json())
 
-                # Parse ServerSentEvents
-                async for byte_payload in resp.content:
-                    # Skip line
-                    if byte_payload == b"\n":
-                        continue
+            # Parse ServerSentEvents
+            async for byte_payload in resp.content:
+                # Skip line
+                if byte_payload == b"\n":
+                    continue
 
-                    payload = byte_payload.decode("utf-8")
+                payload = byte_payload.decode("utf-8")
 
-                    # Event data
-                    if payload.startswith("data:"):
-                        # Decode payload
-                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-                        # Parse payload
-                        try:
-                            response = StreamResponse(**json_payload)
-                        except ValidationError:
-                            # If we failed to parse the payload, then it is an error payload
-                            raise parse_error(resp.status, json_payload)
-                        yield response
+                # Event data
+                if payload.startswith("data:"):
+                    # Decode payload
+                    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+                    # Parse payload
+                    try:
+                        response = StreamResponse(**json_payload)
+                    except ValidationError:
+                        # If we failed to parse the payload, then it is an error payload
+                        raise parse_error(resp.status, json_payload)
+                    yield response