mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: add async context manager for AsyncClient
This commit is contained in:
parent
ff42d33e99
commit
635701ca29
@ -148,3 +148,14 @@ async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_header
|
|||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_client_context_manager(flan_t5_xxl_url, hf_headers):
|
||||||
|
async with AsyncClient(flan_t5_xxl_url, hf_headers) as client:
|
||||||
|
# Perform actions with the client here
|
||||||
|
response = await client.generate("Test input")
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
async for chunk in client.generate_stream("Test input"):
|
||||||
|
assert chunk is not None
|
@ -450,6 +450,15 @@ class AsyncClient:
|
|||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = ClientTimeout(timeout)
|
self.timeout = ClientTimeout(timeout)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.session = ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
Loading…
Reference in New Issue
Block a user