mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-24 10:00:16 +00:00
* chore(neuron): bump version to 0.2.0 * refactor(neuron): use named parameters in inputs helpers This allows to hide the differences between the two backends in terms of input parameters. * refactor(neuron): remove obsolete code paths * fix(neuron): use neuron_config whenever possible * fix(neuron): use new cache import path * fix(neuron): neuron config is not stored in config anymore * fix(nxd): adapt model retrieval to new APIs * fix(generator): emulate greedy in sampling parameters When on-device sampling is enabled, we need to emulate the greedy behaviour using top-k=1, top-p=1, temperature=1. * test(neuron): update models and expectations * feat(neuron): support on-device sampling * fix(neuron): adapt entrypoint * tests(neuron): remove obsolete models * fix(neuron): adjust test expectations for llama on nxd
54 lines
1.2 KiB
Python
Executable File
54 lines
1.2 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
|
|
from text_generation_server.tgi_env import (
|
|
available_cores,
|
|
get_env_dict,
|
|
get_neuron_config_for_model,
|
|
neuron_config_to_env,
|
|
neuronxcc_version,
|
|
parse_cmdline_and_set_env,
|
|
tgi_env_vars,
|
|
)
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def main():
    """
    Set default TGI env variables so that neuron precompiled models work properly.

    When every variable listed in ``tgi_env_vars`` already has a non-empty
    value in the environment, the user's settings are kept and the process
    exits with status 0. Otherwise a neuron config is looked up for the model
    id and revision taken from the parsed command line, and that config is
    handed to ``neuron_config_to_env``.

    :raises Exception: if no compatible neuron config is found.
    :return: None
    """
    cli_args = parse_cmdline_and_set_env()

    # A variable that is unset or set to an empty string counts as missing.
    if all(os.getenv(name) for name in tgi_env_vars):
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            tgi_env_vars,
        )
        sys.exit(0)

    config = get_neuron_config_for_model(cli_args.model_id, cli_args.revision)
    if config:
        neuron_config_to_env(config)
        return

    # No usable config: report the context that was searched, then fail.
    template = "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
    error_text = template.format(get_env_dict(), available_cores, neuronxcc_version)
    logger.error(error_text)
    raise Exception(error_text)
|
|
|
|
|
|
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|