diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 1e25e1b1..9a6f8b6c 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -148,3 +148,14 @@ async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_header with pytest.raises(ValidationError): async for _ in client.generate_stream("test", max_new_tokens=10_000): pass + + +@pytest.mark.asyncio +async def test_async_client_context_manager(flan_t5_xxl_url, hf_headers): + async with AsyncClient(flan_t5_xxl_url, hf_headers) as client: + # Perform actions with the client here + response = await client.generate("Test input") + assert response is not None + + async for chunk in client.generate_stream("Test input"): + assert chunk is not None \ No newline at end of file diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 95d23901..d56b2ad2 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -450,6 +450,15 @@ class AsyncClient: self.cookies = cookies self.timeout = ClientTimeout(timeout) + async def __aenter__(self): + self.session = ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.session.close() + async def chat( self, messages: List[Message],