mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import os
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
import sys
|
|
from text_generation_server import server
|
|
import argparse
|
|
|
|
|
|
def main(args):
    """Log the launch options carried by *args* and start the server.

    Args:
        args: Parsed command-line namespace exposing ``model_id``,
            ``revision``, ``sharded``, ``speculate``, ``dtype``,
            ``trust_remote_code`` and ``uds_path``.
    """
    logger.info("TGIService: starting tgi service .... ")

    # One mapping feeds both the log line and the serve() call, so the two
    # cannot drift apart. Insertion order matches the placeholders below.
    serve_options = {
        "model_id": args.model_id,
        "revision": args.revision,
        "sharded": args.sharded,
        "speculate": args.speculate,
        "dtype": args.dtype,
        "trust_remote_code": args.trust_remote_code,
        "uds_path": args.uds_path,
    }

    summary_template = "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} "
    logger.info(summary_template.format(*serve_options.values()))

    server.serve(**serve_options)
|
|
|
|
|
|
def str_to_bool(value):
    """Convert a command-line boolean spelling into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool(text)``, which is
    ``True`` for every non-empty string -- including ``"False"`` and ``"0"``.
    This converter interprets the text instead.

    Args:
        value: The raw option value (normally a ``str``); an actual ``bool``
            is returned unchanged.

    Returns:
        ``True`` for ``true/t/yes/y/1`` and ``False`` for
        ``false/f/no/n/0/none`` or an empty string (case-insensitive,
        surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling,
            so argparse reports a usage error instead of guessing.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # "" keeps the old bool("") == False result; "none" covers a stringified
    # None and resolves to the safe (disabled) setting.
    if normalized in {"false", "f", "no", "n", "0", "none", ""}:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


if __name__ == "__main__":
    # Script entry point: parse the launch options and start the service.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--revision", type=str)
    # Boolean flags take an explicit value (e.g. "--sharded True"); they must
    # not use type=bool, which would turn "False" into True.
    parser.add_argument("--sharded", type=str_to_bool)
    parser.add_argument("--speculate", type=int, default=None)
    parser.add_argument("--dtype", type=str)
    parser.add_argument("--trust_remote_code", type=str_to_bool)
    parser.add_argument("--uds_path", type=Path)
    args = parser.parse_args()
    main(args)
|