ensure aiohttp session exists

This commit is contained in:
Sabidao 2024-04-21 18:10:28 +03:00 committed by GitHub
parent 3116fb5113
commit 75f954df6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 45 deletions

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"githubPullRequests.ignoredPullRequestBranches": [
"main"
]
}

View File

@ -1,4 +1,5 @@
import json import json
import aiohttp
import requests import requests
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
@ -449,11 +450,16 @@ class AsyncClient:
self.headers = headers self.headers = headers
self.cookies = cookies self.cookies = cookies
self.timeout = ClientTimeout(timeout) self.timeout = ClientTimeout(timeout)
self.session: ClientSession | None = None
async def ensure_session(self):
    """Make sure ``self.session`` holds a usable ``ClientSession``.

    A new session is built from ``self.headers``, ``self.cookies`` and
    ``self.timeout`` only when no session has been created yet or the
    current one reports ``closed``. An existing open session is left
    untouched so it can be reused across requests.
    """
    # Check for None explicitly: a truthiness test would dereference
    # ``None.closed`` on first use and would replace (and leak) a live
    # session on every later call.
    if self.session is None or self.session.closed:
        self.session = ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        )
async def __aenter__(self): async def __aenter__(self):
self.session = ClientSession( await self.ensure_session()
headers=self.headers, cookies=self.cookies, timeout=self.timeout
)
return self return self
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
@ -518,6 +524,8 @@ class AsyncClient:
The tool to use The tool to use
""" """
await self.ensure_session()
request = ChatRequest( request = ChatRequest(
model="tgi", model="tgi",
messages=messages, messages=messages,
@ -542,16 +550,13 @@ class AsyncClient:
return self._chat_stream_response(request) return self._chat_stream_response(request)
async def _chat_single_response(self, request): async def _chat_single_response(self, request):
async with ClientSession( async with self.session.post(
headers=self.headers, cookies=self.cookies, timeout=self.timeout f"{self.base_url}/v1/chat/completions", json=request.dict()
) as session: ) as resp:
async with session.post( payload = await resp.json()
f"{self.base_url}/v1/chat/completions", json=request.dict() if resp.status != 200:
) as resp: raise parse_error(resp.status, payload)
payload = await resp.json() return ChatComplete(**payload)
if resp.status != 200:
raise parse_error(resp.status, payload)
return ChatComplete(**payload)
async def _chat_stream_response(self, request): async def _chat_stream_response(self, request):
async with ClientSession( async with ClientSession(
@ -643,7 +648,7 @@ class AsyncClient:
Returns: Returns:
Response: generated response Response: generated response
""" """
await self.ensure_session()
# Validate parameters # Validate parameters
parameters = Parameters( parameters = Parameters(
best_of=best_of, best_of=best_of,
@ -667,15 +672,12 @@ class AsyncClient:
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession( async with self.session.post(self.base_url, json=request.dict()) as resp:
headers=self.headers, cookies=self.cookies, timeout=self.timeout payload = await resp.json()
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200: if resp.status != 200:
raise parse_error(resp.status, payload) raise parse_error(resp.status, payload)
return Response(**payload[0]) return Response(**payload[0])
async def generate_stream( async def generate_stream(
self, self,
@ -743,6 +745,8 @@ class AsyncClient:
AsyncIterator[StreamResponse]: stream of generated tokens AsyncIterator[StreamResponse]: stream of generated tokens
""" """
# Validate parameters # Validate parameters
await self.ensure_session()
parameters = Parameters( parameters = Parameters(
best_of=None, best_of=None,
details=True, details=True,
@ -765,29 +769,26 @@ class AsyncClient:
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession( async with self.session.post(self.base_url, json=request.dict()) as resp:
headers=self.headers, cookies=self.cookies, timeout=self.timeout if resp.status != 200:
) as session: raise parse_error(resp.status, await resp.json())
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents # Parse ServerSentEvents
async for byte_payload in resp.content: async for byte_payload in resp.content:
# Skip line # Skip line
if byte_payload == b"\n": if byte_payload == b"\n":
continue continue
payload = byte_payload.decode("utf-8") payload = byte_payload.decode("utf-8")
# Event data # Event data
if payload.startswith("data:"): if payload.startswith("data:"):
# Decode payload # Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
# Parse payload # Parse payload
try: try:
response = StreamResponse(**json_payload) response = StreamResponse(**json_payload)
except ValidationError: except ValidationError:
# If we failed to parse the payload, then it is an error payload # If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload) raise parse_error(resp.status, json_payload)
yield response yield response