mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-08 18:32:06 +00:00
54 lines
1.2 KiB
Python
54 lines
1.2 KiB
Python
|
#!/usr/bin/env python
|
||
|
|
||
|
import logging
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
|
||
|
from text_generation_server.tgi_env import (
|
||
|
available_cores,
|
||
|
get_env_dict,
|
||
|
get_neuron_config_for_model,
|
||
|
neuron_config_to_env,
|
||
|
neuronxcc_version,
|
||
|
parse_cmdline_and_set_env,
|
||
|
tgi_env_vars,
|
||
|
)
|
||
|
|
||
|
|
||
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def main():
    """Set default TGI environment variables for a precompiled neuron model.

    Parses the command line, then exits with status 0 right away when every
    variable listed in ``tgi_env_vars`` already has a non-empty value in the
    environment. Otherwise a neuron config is looked up for the requested
    model id and revision and handed to ``neuron_config_to_env``.

    :raises Exception: when no neuron config is returned for the model.
    """
    cli_args = parse_cmdline_and_set_env()

    # An unset or empty variable counts as missing (os.getenv is falsy for both).
    if all(os.getenv(name) for name in tgi_env_vars):
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            tgi_env_vars,
        )
        sys.exit(0)

    config = get_neuron_config_for_model(cli_args.model_id, cli_args.revision)
    if config:
        neuron_config_to_env(config)
        return

    # No usable config: log the context and fail loudly with the same text.
    template = "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
    error_text = template.format(get_env_dict(), available_cores, neuronxcc_version)
    logger.error(error_text)
    raise Exception(error_text)
|
||
|
|
||
|
|
||
|
# Run the environment setup only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|