text-generation-inference/backends/neuron/tgi_entry_point.py
David Corvoysier 79183d1647
Bump neuron SDK version (#3260)
* chore(neuron): bump version to 0.2.0

* refactor(neuron): use named parameters in inputs helpers

This hides the differences between the two backends in terms of
input parameters.

* refactor(neuron): remove obsolete code paths

* fix(neuron): use neuron_config whenever possible

* fix(neuron): use new cache import path

* fix(neuron): neuron config is not stored in config anymore

* fix(nxd): adapt model retrieval to new APIs

* fix(generator): emulate greedy in sampling parameters

When on-device sampling is enabled, we need to emulate the greedy
behaviour using top-k=1, top-p=1, temperature=1.

* test(neuron): update models and expectations

* feat(neuron): support on-device sampling

* fix(neuron): adapt entrypoint

* tests(neuron): remove obsolete models

* fix(neuron): adjust test expectations for llama on nxd
2025-06-10 17:56:25 +02:00
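
The commit message notes that greedy decoding is emulated with top-k=1, top-p=1, temperature=1 when on-device sampling is enabled. The sketch below illustrates why those values behave like greedy decoding. It is a plain-Python toy sampler, not the neuron backend's implementation, and the names `sample_token` and `greedy_sampling_params` are hypothetical.

```python
import math
import random


def sample_token(logits, top_k, top_p, temperature, rng):
    """Toy sampler: temperature scaling, then top-k, then top-p filtering."""
    scaled = [x / temperature for x in logits]
    # Keep the top_k highest-scoring candidates, best first.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax over the surviving candidates (shifted for numerical stability).
    peak = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - peak) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Nucleus filtering: keep the smallest prefix whose mass reaches top_p.
    kept, mass = [], 0.0
    for idx, p in zip(ranked, probs):
        kept.append((idx, p))
        mass += p
        if mass >= top_p:
            break
    indices, kept_probs = zip(*kept)
    return rng.choices(indices, weights=kept_probs, k=1)[0]


def greedy_sampling_params():
    """Hypothetical helper returning the values listed in the commit message."""
    return {"top_k": 1, "top_p": 1.0, "temperature": 1.0}


logits = [0.1, 2.5, -1.0, 2.4]
rng = random.Random(0)
# With top_k=1 only the highest-scoring token survives, so sampling is deterministic.
picks = {sample_token(logits, rng=rng, **greedy_sampling_params()) for _ in range(100)}
greedy_choice = max(range(len(logits)), key=lambda i: logits[i])
```

In this toy model top-k=1 alone forces the argmax. Setting top-p=1 and temperature=1 simply leaves the remaining filters neutral.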


#!/usr/bin/env python
import logging
import os
import sys

from text_generation_server.tgi_env import (
    available_cores,
    get_env_dict,
    get_neuron_config_for_model,
    neuron_config_to_env,
    neuronxcc_version,
    parse_cmdline_and_set_env,
    tgi_env_vars,
)


logger = logging.getLogger(__name__)


def main():
    """
    This script determines proper default TGI env variables for the neuron precompiled models to
    work properly.
    :return:
    """
    args = parse_cmdline_and_set_env()

    # The else branch of the loop runs only when no variable is missing or empty.
    for env_var in tgi_env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info(
            "All env vars %s already set, skipping, the user knows what they are doing",
            tgi_env_vars,
        )
        sys.exit(0)

    # At least one variable is unset: look up a compatible precompiled configuration.
    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)

    if not neuron_config:
        msg = (
            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
        ).format(get_env_dict(), available_cores, neuronxcc_version)
        logger.error(msg)
        raise Exception(msg)

    # Export the selected configuration as TGI environment variables.
    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()
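
The entry point relies on Python's for/else construct: the else branch runs only when the loop finishes without hitting break, that is, when every variable is already set to a non-empty value. A minimal sketch of that check, next to an equivalent `all()` form, is shown below. The function names and the variable names in the sample environment are illustrative assumptions, not the actual contents of `tgi_env_vars`.

```python
def all_set_loop(names, environ):
    """Mirror of the for/else check: True when no name is missing or empty."""
    for name in names:
        if not environ.get(name):
            break  # a missing or empty variable skips the else branch
    else:
        return True
    return False


def all_set(names, environ):
    """Equivalent, more compact formulation using all()."""
    return all(environ.get(name) for name in names)


# Hypothetical environment: one variable set, one present but empty.
sample_env = {"EXAMPLE_VAR_A": "4", "EXAMPLE_VAR_B": ""}
only_a = all_set_loop(["EXAMPLE_VAR_A"], sample_env)
a_and_b = all_set_loop(["EXAMPLE_VAR_A", "EXAMPLE_VAR_B"], sample_env)
```

As with `os.getenv` in the script, a variable that is present but empty counts as unset, so defaults would still be computed in that case.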