2025-04-07 14:55:03 +00:00
|
|
|
import asyncio
import json
import os
from typing import Any, Dict, Generator

import pytest
|
2025-04-10 07:46:59 +00:00
|
|
|
from test_generate import TEST_CONFIGS
|
2025-04-07 14:55:03 +00:00
|
|
|
|
|
|
|
def _has_unknown_output(config: Dict[str, Any]) -> bool:
    """Return True if either recorded reference output is the "unknown" placeholder.

    The greedy output is checked first; the batch output is only read when the
    greedy one is already known.
    """
    return (
        config["expected_greedy_output"] == "unknown"
        or config["expected_batch_output"] == "unknown"
    )


# Subset of TEST_CONFIGS whose expected outputs still need to be captured.
UNKNOWN_CONFIGS = {
    config_name: model_config
    for config_name, model_config in TEST_CONFIGS.items()
    if _has_unknown_output(model_config)
}
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module", params=list(UNKNOWN_CONFIGS))
def test_config(request) -> Dict[str, Any]:
    """Provide one model configuration whose reference outputs are still unknown.

    The fixture is parametrized over the keys of ``UNKNOWN_CONFIGS``. The
    returned dict is a shallow copy of the selected entry, with an extra
    ``"test_name"`` key set to the parametrization key.
    """
    # Copy rather than mutate in place: the values of UNKNOWN_CONFIGS are the
    # very same dict objects held by the imported TEST_CONFIGS, so writing into
    # them would leak "test_name" into every other user of TEST_CONFIGS.
    return {**UNKNOWN_CONFIGS[request.param], "test_name": request.param}
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def test_name(test_config):
    """Expose the current configuration's name as a standalone fixture."""
    config_name = test_config["test_name"]
    yield config_name
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def tgi_service(launcher, test_config, test_name) -> Generator:
    """Start a TGI service for the configured model and yield it.

    The service is created through the ``launcher`` fixture's context manager,
    so it stays alive for the whole module and is torn down on exit.
    """
    model_id = test_config["model_id"]
    with launcher(model_id, test_name) as running_service:
        yield running_service
|
|
|
|
|
|
|
|
|
|
|
|
def _save_results(output_file: str, test_name: str, entry: Dict[str, Any]) -> None:
    """Merge ``entry`` under ``test_name`` into the JSON file at ``output_file``.

    Entries for other tests already in the file are preserved; an existing
    entry with the same name is overwritten. The file is created if missing.
    """
    results: Dict[str, Any] = {}
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            results = json.load(f)

    results[test_name] = entry

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


@pytest.mark.asyncio
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
    """Capture single-request and batched outputs for models with unknown outputs.

    One request is sent on its own ("greedy"), then four identical requests
    are sent concurrently ("batch"). The generated texts are merged into a
    shared JSON file keyed by ``test_name``.
    """
    print(f"Testing {test_name} with {test_config['model_id']}")

    # Wait for service to be ready
    await tgi_service.health(1000)
    client = tgi_service.client

    # Single request
    print("Testing single request...")
    response = await client.generate(
        test_config["input"],
        max_new_tokens=32,
    )
    greedy_output = response.generated_text

    # Multiple requests (batch). They must be in flight at the same time so
    # the server can batch them; awaiting each one before sending the next
    # (as a plain loop would) never puts more than one request in flight.
    print("Testing batch requests...")
    batch_responses = await asyncio.gather(
        *(
            client.generate(test_config["input"], max_new_tokens=32)
            for _ in range(4)
        )
    )
    # gather preserves submission order, so the list order is deterministic.
    responses = [r.generated_text for r in batch_responses]

    # NOTE: this path is relative to the current working directory, so the
    # test must be launched from the repository root for it to resolve.
    output_file = "server/integration-tests/expected_outputs.json"
    _save_results(
        output_file,
        test_name,
        {
            "model_id": test_config["model_id"],
            "input": test_config["input"],
            "greedy_output": greedy_output,
            "batch_outputs": responses,
            "args": test_config["args"],
        },
    )

    print(f"\nResults for {test_name} saved to {output_file}")
|