import os

import pytest

from huggingface_hub.errors import ValidationError
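
# Note: neuron_launcher and neuron_model_config are pytest fixtures assumed to
# be provided by the surrounding test suite (e.g. a conftest.py); they are not
# defined in this module.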

@pytest.fixture(scope="module", params=["hub-neuron", "hub", "local-neuron"])
async def tgi_service(request, neuron_launcher, neuron_model_config):
    """Expose a TGI service corresponding to a model configuration.

    For each model configuration, the service will be started using the following
    deployment options:
    - from the original hub model (export parameters chosen after a hub lookup),
    - from the pre-exported neuron model on the hub,
    - from a local path to the neuron model.
    """
    # the tgi_env.py script will take care of setting these
    for var in [
        "MAX_BATCH_SIZE",
        "MAX_INPUT_TOKENS",
        "MAX_TOTAL_TOKENS",
        "HF_NUM_CORES",
        "HF_AUTO_CAST_TYPE",
    ]:
        if var in os.environ:
            del os.environ[var]
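    # Select the model source matching the fixture parameter: the original hub
    # model, the pre-exported neuron model on the hub, or a local neuron export.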
    if request.param == "hub":
        model_name_or_path = neuron_model_config["model_id"]
    elif request.param == "hub-neuron":
        model_name_or_path = neuron_model_config["neuron_model_id"]
    else:
        model_name_or_path = neuron_model_config["neuron_model_path"]
    service_name = neuron_model_config["name"]
    with neuron_launcher(service_name, model_name_or_path) as tgi_service:
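        # Wait for the service to report healthy; the first launch may need to
        # export and compile the model for neuron, hence the generous timeout.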
        await tgi_service.health(600)
        yield tgi_service


@pytest.mark.asyncio
async def test_model_single_request(tgi_service):
    # Just verify that generation works and nothing is raised with several sets of params

    # No params
    await tgi_service.client.text_generation(
        "What is Deep Learning?",
    )
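
    # Greedy decoding with details: with no early stop, the model should
    # generate exactly max_new_tokens tokens, as asserted below.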
    response = await tgi_service.client.text_generation(
        "How to cook beans?",
        max_new_tokens=17,
        details=True,
        decoder_input_details=True,
    )
    assert response.details.generated_tokens == 17

    # Check that requesting too many tokens raises a validation error
    try:
        await tgi_service.client.text_generation("What is Deep Learning?", max_new_tokens=170000)
    except ValidationError:
        pass
    else:
        raise AssertionError(
            "The previous text generation request should have failed "
            "because too many tokens were requested, but it succeeded"
        )

    # Sampling with a fixed seed, so the output is reproducible
    await tgi_service.client.text_generation(
        "What is Deep Learning?",
        do_sample=True,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        max_new_tokens=128,
        seed=42,
    )
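

# Usage sketch (assumption: pytest-asyncio running in auto mode for the async
# fixture above, plus the fixtures noted at the top of this module, as in the
# TGI integration test suite):
#
#   pytest -sv <path-to-this-test-file>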