feat: add async context manager for AsyncClient

This commit is contained in:
Sabidao 2024-04-08 17:16:19 +03:00 committed by GitHub
parent ff42d33e99
commit 635701ca29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 0 deletions

View File

@ -148,3 +148,14 @@ async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_header
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass
@pytest.mark.asyncio
async def test_async_client_context_manager(flan_t5_xxl_url, hf_headers):
async with AsyncClient(flan_t5_xxl_url, hf_headers) as client:
# Perform actions with the client here
response = await client.generate("Test input")
assert response is not None
async for chunk in client.generate_stream("Test input"):
assert chunk is not None

View File

@ -450,6 +450,15 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout)
async def __aenter__(self):
self.session = ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
)
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.close()
async def chat(
self,
messages: List[Message],