mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
ensure aiohttp session exists
This commit is contained in:
parent
3116fb5113
commit
75f954df6c
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"githubPullRequests.ignoredPullRequestBranches": [
|
||||||
|
"main"
|
||||||
|
]
|
||||||
|
}
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from aiohttp import ClientSession, ClientTimeout
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
@ -449,11 +450,16 @@ class AsyncClient:
|
|||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = ClientTimeout(timeout)
|
self.timeout = ClientTimeout(timeout)
|
||||||
|
self.session: ClientSession | None = None
|
||||||
|
|
||||||
|
async def ensure_session(self):
    """Make sure ``self.session`` holds a usable aiohttp ``ClientSession``.

    A new session is built from ``self.headers``, ``self.cookies`` and
    ``self.timeout`` when no session has been created yet or when the
    current one has been closed. An open session is reused unchanged.
    """
    # The `is None` test must come first. The previous guard,
    # `self.session or self.session.closed`, raised AttributeError when no
    # session existed and replaced (leaked) a live session on every call.
    if self.session is None or self.session.closed:
        self.session = ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        )
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.session = ClientSession(
|
await self.ensure_session()
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
|
||||||
)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
@ -518,6 +524,8 @@ class AsyncClient:
|
|||||||
The tool to use
|
The tool to use
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
await self.ensure_session()
|
||||||
|
|
||||||
request = ChatRequest(
|
request = ChatRequest(
|
||||||
model="tgi",
|
model="tgi",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@ -542,16 +550,13 @@ class AsyncClient:
|
|||||||
return self._chat_stream_response(request)
|
return self._chat_stream_response(request)
|
||||||
|
|
||||||
async def _chat_single_response(self, request):
|
async def _chat_single_response(self, request):
|
||||||
async with ClientSession(
|
async with self.session.post(
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
||||||
) as session:
|
) as resp:
|
||||||
async with session.post(
|
payload = await resp.json()
|
||||||
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
if resp.status != 200:
|
||||||
) as resp:
|
raise parse_error(resp.status, payload)
|
||||||
payload = await resp.json()
|
return ChatComplete(**payload)
|
||||||
if resp.status != 200:
|
|
||||||
raise parse_error(resp.status, payload)
|
|
||||||
return ChatComplete(**payload)
|
|
||||||
|
|
||||||
async def _chat_stream_response(self, request):
|
async def _chat_stream_response(self, request):
|
||||||
async with ClientSession(
|
async with ClientSession(
|
||||||
@ -643,7 +648,7 @@ class AsyncClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Response: generated response
|
Response: generated response
|
||||||
"""
|
"""
|
||||||
|
await self.ensure_session()
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
@ -667,15 +672,12 @@ class AsyncClient:
|
|||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
async with ClientSession(
|
async with self.session.post(self.base_url, json=request.dict()) as resp:
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
payload = await resp.json()
|
||||||
) as session:
|
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
|
||||||
payload = await resp.json()
|
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise parse_error(resp.status, payload)
|
raise parse_error(resp.status, payload)
|
||||||
return Response(**payload[0])
|
return Response(**payload[0])
|
||||||
|
|
||||||
async def generate_stream(
|
async def generate_stream(
|
||||||
self,
|
self,
|
||||||
@ -743,6 +745,8 @@ class AsyncClient:
|
|||||||
AsyncIterator[StreamResponse]: stream of generated tokens
|
AsyncIterator[StreamResponse]: stream of generated tokens
|
||||||
"""
|
"""
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
|
await self.ensure_session()
|
||||||
|
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=None,
|
best_of=None,
|
||||||
details=True,
|
details=True,
|
||||||
@ -765,29 +769,26 @@ class AsyncClient:
|
|||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
async with ClientSession(
|
async with self.session.post(self.base_url, json=request.dict()) as resp:
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
if resp.status != 200:
|
||||||
) as session:
|
raise parse_error(resp.status, await resp.json())
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
raise parse_error(resp.status, await resp.json())
|
|
||||||
|
|
||||||
# Parse ServerSentEvents
|
# Parse ServerSentEvents
|
||||||
async for byte_payload in resp.content:
|
async for byte_payload in resp.content:
|
||||||
# Skip line
|
# Skip line
|
||||||
if byte_payload == b"\n":
|
if byte_payload == b"\n":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = byte_payload.decode("utf-8")
|
payload = byte_payload.decode("utf-8")
|
||||||
|
|
||||||
# Event data
|
# Event data
|
||||||
if payload.startswith("data:"):
|
if payload.startswith("data:"):
|
||||||
# Decode payload
|
# Decode payload
|
||||||
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
||||||
# Parse payload
|
# Parse payload
|
||||||
try:
|
try:
|
||||||
response = StreamResponse(**json_payload)
|
response = StreamResponse(**json_payload)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
# If we failed to parse the payload, then it is an error payload
|
# If we failed to parse the payload, then it is an error payload
|
||||||
raise parse_error(resp.status, json_payload)
|
raise parse_error(resp.status, json_payload)
|
||||||
yield response
|
yield response
|
||||||
|
Loading…
Reference in New Issue
Block a user