# Mirror of https://github.com/huggingface/text-generation-inference.git
# Synced 2025-06-09 19:02:09 +00:00
import os
from tempfile import TemporaryDirectory

import pytest
from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
from optimum.neuron.utils import map_torch_dtype

from text_generation_server.tgi_env import (
    get_neuron_config_for_model,
    lookup_compatible_cached_model,
    neuron_config_to_env,
)


def test_get_neuron_config_for_model(neuron_model_config):
    """Check that ``get_neuron_config_for_model`` honours the serving env vars.

    The server resolves MAX_BATCH_SIZE, MAX_TOTAL_TOKENS, HF_AUTO_CAST_TYPE and
    HF_NUM_CORES from the environment; set them from the fixture's export
    kwargs and verify the resolved neuron config reflects the same values.

    Args:
        neuron_model_config: fixture dict with at least ``neuron_model_path``
            and ``export_kwargs`` (batch_size, sequence_length,
            auto_cast_type, num_cores).
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    export_kwargs = neuron_model_config["export_kwargs"]
    env_overrides = {
        "MAX_BATCH_SIZE": str(export_kwargs["batch_size"]),
        "MAX_TOTAL_TOKENS": str(export_kwargs["sequence_length"]),
        "HF_AUTO_CAST_TYPE": export_kwargs["auto_cast_type"],
        "HF_NUM_CORES": str(export_kwargs["num_cores"]),
    }
    # Remember prior values so the environment can be restored: the original
    # version leaked these variables into every subsequent test in the session.
    saved = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)
    try:
        neuron_config = get_neuron_config_for_model(neuron_model_path)
        assert neuron_config is not None
        assert neuron_config.batch_size == export_kwargs["batch_size"]
        assert neuron_config.sequence_length == export_kwargs["sequence_length"]
        assert neuron_config.tp_degree == export_kwargs["num_cores"]
        # NxD configs expose the dtype as `torch_dtype`; other configs expose
        # it as `auto_cast_type`. Compare via map_torch_dtype either way.
        if isinstance(neuron_config, NxDNeuronConfig):
            actual_dtype = neuron_config.torch_dtype
        else:
            actual_dtype = neuron_config.auto_cast_type
        assert map_torch_dtype(actual_dtype) == map_torch_dtype(
            export_kwargs["auto_cast_type"]
        )
    finally:
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value


@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
def test_lookup_compatible_cached_model(model_id: str):
    """A compatible cached neuron config must exist for the given model id."""
    # Second argument (revision) is None: look up the default revision.
    assert lookup_compatible_cached_model(model_id, None) is not None


def test_neuron_config_to_env(neuron_model_config) -> None:
    """Check that ``neuron_config_to_env`` writes the serving vars to $ENV_FILEPATH.

    Points ENV_FILEPATH at a file inside a temporary directory, invokes the
    helper, and asserts the generated shell script exports values matching the
    resolved neuron config.

    Args:
        neuron_model_config: fixture dict providing ``neuron_model_path``.
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    neuron_config = get_neuron_config_for_model(neuron_model_path)
    # Remember the prior value: the original version left ENV_FILEPATH set to a
    # path inside the (now deleted) temp dir, leaking a dangling path into
    # every later test in the session.
    previous_filepath = os.environ.get("ENV_FILEPATH")
    try:
        with TemporaryDirectory() as temp_dir:
            os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
            neuron_config_to_env(neuron_config)
            with open(os.environ["ENV_FILEPATH"], "r") as env_file:
                env_content = env_file.read()
            assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
            assert (
                f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
                in env_content
            )
            assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
            if hasattr(neuron_config, "torch_dtype"):
                # map_torch_dtype yields e.g. "torch.bfloat16"; keep the part
                # after the last dot ("bfloat16") to match the script's format.
                auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
                    "."
                )[-1]
            else:
                auto_cast_type = neuron_config.auto_cast_type
            assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
    finally:
        if previous_filepath is None:
            os.environ.pop("ENV_FILEPATH", None)
        else:
            os.environ["ENV_FILEPATH"] = previous_filepath