mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: add async context manager for AsyncClient
This commit is contained in:
parent
ff42d33e99
commit
635701ca29
@ -148,3 +148,14 @@ async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_header
|
||||
with pytest.raises(ValidationError):
|
||||
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_context_manager(flan_t5_xxl_url, hf_headers):
|
||||
async with AsyncClient(flan_t5_xxl_url, hf_headers) as client:
|
||||
# Perform actions with the client here
|
||||
response = await client.generate("Test input")
|
||||
assert response is not None
|
||||
|
||||
async for chunk in client.generate_stream("Test input"):
|
||||
assert chunk is not None
|
@ -450,6 +450,15 @@ class AsyncClient:
|
||||
self.cookies = cookies
|
||||
self.timeout = ClientTimeout(timeout)
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.close()
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
|
Loading…
Reference in New Issue
Block a user