diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 76ac80d3..c998de41 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -21,6 +21,16 @@ def test_generate(flan_t5_xxl_url, hf_headers): ) +def test_generate_best_of(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) + response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + def test_generate_not_found(fake_url, hf_headers): client = Client(fake_url, hf_headers) with pytest.raises(NotFoundError):