diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index d53c2a4d..8972dfd1 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -87,6 +87,19 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers): ) +@pytest.mark.asyncio +async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): + client = AsyncClient(flan_t5_xxl_url, hf_headers) + response = await client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True + ) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + @pytest.mark.asyncio async def test_generate_async_not_found(fake_url, hf_headers): client = AsyncClient(fake_url, hf_headers)