Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-17 14:52:09 +00:00
fix(neuron): adapt entrypoint
parent 3e977bde99
commit 5d2b159182
@@ -159,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 # Final image
 FROM neuron

-COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
 COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
135 changes: backends/neuron/tgi_env.py → backends/neuron/server/text_generation_server/tgi_env.py (Executable file → Normal file)
@@ -6,12 +6,11 @@ import os
 import sys
 from typing import Any, Dict, List, Optional

 from huggingface_hub import constants

 from optimum.neuron.modeling_decoder import get_available_cores
 from optimum.neuron.cache import get_hub_cached_entries
 from optimum.neuron.configuration_utils import NeuronConfig
 from optimum.neuron.utils.version_utils import get_neuronxcc_version
+from optimum.neuron.utils import map_torch_dtype


 logger = logging.getLogger(__name__)
@@ -24,15 +23,9 @@ tgi_router_env_vars = [
 ]
 tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]

-env_config_peering = [
-    ("MAX_BATCH_SIZE", "batch_size"),
-    ("MAX_TOTAL_TOKENS", "sequence_length"),
-    ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
-    ("HF_NUM_CORES", "num_cores"),
-]

 # By the end of this script all env var should be specified properly
-env_vars = tgi_server_env_vars + tgi_router_env_vars
+tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars

 available_cores = get_available_cores()
 neuronxcc_version = get_neuronxcc_version()
@@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:


 def neuron_config_to_env(neuron_config):
+    if isinstance(neuron_config, NeuronConfig):
+        neuron_config = neuron_config.to_dict()
     with open(os.environ["ENV_FILEPATH"], "w") as f:
-        for env_var, config_key in env_config_peering:
-            f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
+        f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
+        f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
+        f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
+        config_key = (
+            "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
+        )
+        auto_cast_type = neuron_config[config_key]
+        f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
         max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
         if not max_input_tokens:
            max_input_tokens = int(neuron_config["sequence_length"]) // 2
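For orientation, a minimal sketch of what this function now writes, with the function above in scope and a hypothetical config dict (values made up, not from the commit); the MAX_INPUT_TOKENS export itself falls past the end of this hunk, but the // 2 default is visible above:

    import os
    import tempfile

    # Hypothetical config values for illustration only.
    sample = {
        "batch_size": 4,
        "sequence_length": 4096,
        "tp_degree": 2,
        "auto_cast_type": "bf16",
    }
    os.environ["ENV_FILEPATH"] = os.path.join(tempfile.mkdtemp(), "env.sh")
    neuron_config_to_env(sample)
    # The file at $ENV_FILEPATH would then contain:
    #   export MAX_BATCH_SIZE=4
    #   export MAX_TOTAL_TOKENS=4096
    #   export HF_NUM_CORES=2
    #   export HF_AUTO_CAST_TYPE=bf16
    #   export MAX_INPUT_TOKENS=2048   (sequence_length // 2 when the env var is unset)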
@@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):


 def sort_neuron_configs(dictionary):
-    return -dictionary["num_cores"], -dictionary["batch_size"]
+    return -dictionary["tp_degree"], -dictionary["batch_size"]


 def lookup_compatible_cached_model(
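Because sort_neuron_configs negates both values, using it as a sort key orders cached configurations by descending tp_degree, then descending batch_size. A minimal sketch with hypothetical entries:

    # Hypothetical cache entries; the largest tp_degree (then batch_size) sorts first.
    entries = [
        {"tp_degree": 2, "batch_size": 1},
        {"tp_degree": 8, "batch_size": 1},
        {"tp_degree": 8, "batch_size": 4},
    ]
    entries.sort(key=sort_neuron_configs)
    # -> tp_degree 8 / batch 4, then tp_degree 8 / batch 1, then tp_degree 2 / batch 1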
@@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
 ) -> Optional[Dict[str, Any]]:
     # Reuse the same mechanic as the one in use to configure the tgi server part
     # The only difference here is that we stay as flexible as possible on the compatibility part
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)

     logger.debug(
         "Found %d cached entries for model %s, revision %s",
@@ -155,15 +156,15 @@ def lookup_compatible_cached_model(


 def check_env_and_neuron_config_compatibility(
-    neuron_config: Dict[str, Any], check_compiler_version: bool
+    neuron_config_dict: Dict[str, Any], check_compiler_version: bool
 ) -> bool:
     logger.debug(
         "Checking the provided neuron config %s is compatible with the local setup and provided environment",
-        neuron_config,
+        neuron_config_dict,
     )

     # Local setup compat checks
-    if neuron_config["num_cores"] > available_cores:
+    if neuron_config_dict["tp_degree"] > available_cores:
         logger.debug(
             "Not enough neuron cores available to run the provided neuron config"
         )
@@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(

     if (
         check_compiler_version
-        and neuron_config["compiler_version"] != neuronxcc_version
+        and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
     ):
         logger.debug(
             "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
             neuronxcc_version,
-            neuron_config["compiler_version"],
+            neuron_config_dict["neuronxcc_version"],
         )
         return False

-    for env_var, config_key in env_config_peering:
-        neuron_config_value = str(neuron_config[config_key])
-        env_value = os.getenv(env_var, str(neuron_config_value))
+    batch_size = os.getenv("MAX_BATCH_SIZE", None)
+    if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
+        logger.debug(
+            "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
+            os.getenv("MAX_BATCH_SIZE"),
+            neuron_config_dict["batch_size"],
+        )
+        return False
+    max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
+    if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
+        max_total_tokens
+    ):
+        logger.debug(
+            "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
+            max_total_tokens,
+            neuron_config_dict["sequence_length"],
+        )
+        return False
+    num_cores = os.getenv("HF_NUM_CORES", None)
+    if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
+        logger.debug(
+            "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
+            num_cores,
+            neuron_config_dict["tp_degree"],
+        )
+        return False
+    auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
+    if auto_cast_type is not None:
+        config_key = (
+            "auto_cast_type"
+            if "auto_cast_type" in neuron_config_dict
+            else "torch_dtype"
+        )
+        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
+        env_value = map_torch_dtype(auto_cast_type)
         if env_value != neuron_config_value:
             logger.debug(
-                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
-                env_var,
-                config_key,
+                "The provided auto cast type and the neuron config param differ (%s != %s)",
                 env_value,
                 neuron_config_value,
             )
             return False

     max_input_tokens = int(
         os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
     )
     if max_input_tokens > 0:
-        sequence_length = neuron_config["sequence_length"]
+        if "max_context_length" in neuron_config_dict:
+            sequence_length = neuron_config_dict["max_context_length"]
+        else:
+            sequence_length = neuron_config_dict["sequence_length"]
         if max_input_tokens >= sequence_length:
             logger.debug(
                 "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
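To make the new contract concrete, a sketch of the check with a hypothetical config dict, run with the module's globals in scope (it assumes the local machine has at least 2 Neuron cores and that the config's compiler version matches):

    import os

    # Hypothetical cached-config dict; the keys mirror the ones read above.
    cfg = {
        "batch_size": 4,
        "sequence_length": 4096,
        "tp_degree": 2,
        "neuronxcc_version": neuronxcc_version,  # pretend the compiled version matches
        "torch_dtype": "bfloat16",
    }
    os.environ["MAX_BATCH_SIZE"] = "8"  # requests more than the config supports
    assert not check_env_and_neuron_config_compatibility(cfg, check_compiler_version=True)
    os.environ["MAX_BATCH_SIZE"] = "4"  # within bounds, so the check passes
    assert check_env_and_neuron_config_compatibility(cfg, check_compiler_version=True)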
@@ -211,48 +244,29 @@ def check_env_and_neuron_config_compatibility(

 def get_env_dict() -> Dict[str, str]:
     d = {}
-    for k in env_vars:
+    for k in tgi_env_vars:
         d[k] = os.getenv(k)
     return d


-def main():
-    """
-    This script determines proper default TGI env variables for the neuron precompiled models to
-    work properly
-    :return:
-    """
-    args = parse_cmdline_and_set_env()
-
-    for env_var in env_vars:
-        if not os.getenv(env_var):
-            break
-    else:
-        logger.info(
-            "All env vars %s already set, skipping, user know what they are doing",
-            env_vars,
-        )
-        sys.exit(0)
-
-    cache_dir = constants.HF_HUB_CACHE
-
-    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
-
+def get_neuron_config_for_model(
+    model_name_or_path: str, revision: Optional[str] = None
+) -> NeuronConfig:
     try:
         neuron_config = NeuronConfig.from_pretrained(
-            args.model_id, revision=args.revision
+            model_name_or_path, revision=revision
         )
     except Exception as e:
         logger.debug(
             "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
-            args.model_id,
-            args.revision,
+            model_name_or_path,
+            revision,
             e,
         )
         neuron_config = None
     if neuron_config is not None:
         compatible = check_env_and_neuron_config_compatibility(
-            neuron_config, check_compiler_version=False
+            neuron_config.to_dict(), check_compiler_version=False
         )
         if not compatible:
             env_dict = get_env_dict()
@@ -262,17 +276,6 @@ def main():
             logger.error(msg)
             raise Exception(msg)
     else:
-        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
+        neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)

-    if not neuron_config:
-        msg = (
-            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
-        ).format(get_env_dict(), available_cores, neuronxcc_version)
-        logger.error(msg)
-        raise Exception(msg)
-
-    neuron_config_to_env(neuron_config)
-
-
-if __name__ == "__main__":
-    main()
+    return neuron_config
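The net effect on tgi_env.py: main() moves out of the module, and get_neuron_config_for_model becomes the single lookup path that either validates a local NeuronConfig against the environment or falls back to the hub cache. A minimal usage sketch (the model id is a placeholder, not a real repository):

    # Hypothetical usage of the refactored module surface.
    neuron_config = get_neuron_config_for_model("my-org/my-neuron-model", revision=None)
    if neuron_config is not None:
        neuron_config_to_env(neuron_config)  # materialize the env defaults for the router/server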
63 changes: backends/neuron/tests/test_entry_point.py (new Normal file)
@@ -0,0 +1,63 @@
+import os
+import pytest
+from tempfile import TemporaryDirectory
+
+from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
+from optimum.neuron.utils import map_torch_dtype
+
+from text_generation_server.tgi_env import (
+    get_neuron_config_for_model,
+    lookup_compatible_cached_model,
+    neuron_config_to_env,
+)
+
+
+def test_get_neuron_config_for_model(neuron_model_config):
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    assert neuron_config is not None
+    assert neuron_config.batch_size == export_kwargs["batch_size"]
+    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
+    assert neuron_config.tp_degree == export_kwargs["num_cores"]
+    if isinstance(neuron_config, NxDNeuronConfig):
+        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+    else:
+        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+
+
+@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
+def test_lookup_compatible_cached_model(model_id: str):
+    neuron_config = lookup_compatible_cached_model(model_id, None)
+    assert neuron_config is not None
+
+
+def test_neuron_config_to_env(neuron_model_config) -> None:
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    with TemporaryDirectory() as temp_dir:
+        os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
+        neuron_config_to_env(neuron_config)
+        with open(os.environ["ENV_FILEPATH"], "r") as env_file:
+            env_content = env_file.read()
+        assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
+        assert (
+            f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
+            in env_content
+        )
+        assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
+        if hasattr(neuron_config, "torch_dtype"):
+            auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
+                "."
+            )[-1]
+        else:
+            auto_cast_type = neuron_config.auto_cast_type
+        assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
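The neuron_model_config fixture is defined elsewhere in the test suite (presumably a conftest); from the assertions above one can infer roughly this shape, with placeholder values:

    # Hypothetical fixture contents inferred from the test bodies, not from the commit.
    neuron_model_config = {
        "neuron_model_path": "/data/exported-neuron-model",  # placeholder path
        "export_kwargs": {
            "batch_size": 1,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    }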
@@ -9,7 +9,7 @@ touch $ENV_FILEPATH

 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

-${SCRIPT_DIR}/tgi_env.py $@
+${SCRIPT_DIR}/tgi_entry_point.py $@

 source $ENV_FILEPATH
53 changes: backends/neuron/tgi_entry_point.py (new Executable file)
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import sys
+
+
+from text_generation_server.tgi_env import (
+    available_cores,
+    get_env_dict,
+    get_neuron_config_for_model,
+    neuron_config_to_env,
+    neuronxcc_version,
+    parse_cmdline_and_set_env,
+    tgi_env_vars,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    """
+    This script determines the proper default TGI env variables for neuron precompiled models to
+    work properly
+    :return:
+    """
+    args = parse_cmdline_and_set_env()
+
+    for env_var in tgi_env_vars:
+        if not os.getenv(env_var):
+            break
+    else:
+        logger.info(
+            "All env vars %s already set, skipping, users know what they are doing",
+            tgi_env_vars,
+        )
+        sys.exit(0)
+
+    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
+
+    if not neuron_config:
+        msg = (
+            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+        ).format(get_env_dict(), available_cores, neuronxcc_version)
+        logger.error(msg)
+        raise Exception(msg)
+
+    neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+    main()
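The for/else idiom in main() is easy to misread: the else branch runs only when the loop finishes without hitting break, i.e. when every expected variable is already set. A standalone sketch:

    import os

    required = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS"]  # illustrative subset, not the full list
    for var in required:
        if not os.getenv(var):
            break  # at least one variable missing: fall through and compute defaults
    else:
        print("all variables set, nothing to do")  # runs only if the loop never broke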