text-generation-inference/backends/neuron/tests/integration/test_implicit_env.py


import os

import pytest
from huggingface_hub.errors import ValidationError


@pytest.fixture(scope="module", params=["hub-neuron", "hub", "local-neuron"])
async def tgi_service(request, launcher, neuron_model_config):
    """Expose a TGI service corresponding to a model configuration

    For each model configuration, the service will be started using the following
    deployment options:
    - from the hub original model (export parameters chosen after hub lookup),
    - from the hub pre-exported neuron model,
    - from a local path to the neuron model.
    """
    # the tgi_env.py script will take care of setting these
    for var in [
        "MAX_BATCH_SIZE",
        "MAX_INPUT_TOKENS",
        "MAX_TOTAL_TOKENS",
        "HF_NUM_CORES",
        "HF_AUTO_CAST_TYPE",
    ]:
        if var in os.environ:
            del os.environ[var]
    if request.param == "hub":
        model_name_or_path = neuron_model_config["model_id"]
    elif request.param == "hub-neuron":
        model_name_or_path = neuron_model_config["neuron_model_id"]
    else:
        model_name_or_path = neuron_model_config["neuron_model_path"]
    service_name = neuron_model_config["name"]
    with launcher(service_name, model_name_or_path) as tgi_service:
        await tgi_service.health(600)
        yield tgi_service
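
# The requests below go through tgi_service.client, presumably a huggingface_hub
# AsyncInferenceClient bound to the launched endpoint (the ValidationError
# imported above is what it raises for invalid request parameters).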

@pytest.mark.asyncio
async def test_model_single_request(tgi_service):
    # Just verify that the generation works, and nothing is raised, with several sets of params
    # No params
    await tgi_service.client.text_generation(
        "What is Deep Learning?",
    )
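    # details=True asks for generation metadata (used for the assertion below);
    # decoder_input_details=True additionally requests details about the prompt tokens.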
    response = await tgi_service.client.text_generation(
        "How to cook beans ?",
        max_new_tokens=17,
        details=True,
        decoder_input_details=True,
    )
    assert response.details.generated_tokens == 17
    # check error: requesting too many new tokens should raise a ValidationError
    try:
        await tgi_service.client.text_generation("What is Deep Learning?", max_new_tokens=170000)
    except ValidationError:
        pass
    else:
        raise AssertionError(
            "The previous text generation request should have failed "
            "because too many tokens were requested, but it succeeded"
        )
    # Sampling
    await tgi_service.client.text_generation(
        "What is Deep Learning?",
        do_sample=True,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        max_new_tokens=128,
        seed=42,
    )
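
# Typical local invocation (the exact path and options are assumptions about the
# developer setup, not something this file mandates):
#   pytest backends/neuron/tests/integration/test_implicit_env.py -x -v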