mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
* wip(gaudi): import server and dockerfile from tgi-gaudi fork * feat(gaudi): new gaudi backend working * fix: fix style * fix prehooks issues * fix(gaudi): refactor server and implement requested changes
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
import os
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
from text_generation_server import server
|
|
import argparse
|
|
from text_generation_server.utils.adapter import parse_lora_adapters
|
|
|
|
|
|
def main(args):
    """Entry point for the Gaudi TGI service.

    Logs the launch configuration, resolves any LoRA adapters from the
    LORA_ADAPTERS environment variable, and hands control to the gRPC
    server (blocks until the server exits).
    """
    logger.info("TGIService: starting tgi service .... ")

    # Implicit concatenation keeps the emitted log line byte-identical
    # to the historical single-line template.
    startup_template = (
        "TGIService: --model_id {}, --revision {}, --sharded {}, "
        "--speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} "
    )
    logger.info(
        startup_template.format(
            args.model_id,
            args.revision,
            args.sharded,
            args.speculate,
            args.dtype,
            args.trust_remote_code,
            args.uds_path,
        )
    )

    # Collect everything server.serve needs in one place, then launch.
    serve_kwargs = dict(
        model_id=args.model_id,
        lora_adapters=parse_lora_adapters(os.getenv("LORA_ADAPTERS")),
        revision=args.revision,
        sharded=args.sharded,
        quantize=args.quantize,
        speculate=args.speculate,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        uds_path=args.uds_path,
        max_input_tokens=args.max_input_tokens,
    )
    server.serve(**serve_kwargs)
|
|
|
|
|
|
def str_to_bool(value: str) -> bool:
    """Parse a textual boolean CLI value ('True'/'false'/'1'/'no', ...).

    Fixes the classic argparse pitfall: ``type=bool`` calls ``bool(s)``,
    which is True for ANY non-empty string — so ``--sharded False``
    previously parsed as True.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "1", "yes", "y"):
        return True
    if lowered in ("false", "f", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # CLI mirrors the arguments the TGI launcher forwards to this service.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--revision", type=str)
    # Was ``type=bool``: any non-empty string (including "False") was truthy.
    parser.add_argument("--sharded", type=str_to_bool)
    parser.add_argument("--speculate", type=int, default=None)
    parser.add_argument("--dtype", type=str)
    # Same ``type=bool`` fix as --sharded.
    parser.add_argument("--trust_remote_code", type=str_to_bool)
    parser.add_argument("--uds_path", type=Path)
    parser.add_argument("--quantize", type=str)
    parser.add_argument("--max_input_tokens", type=int)
    args = parser.parse_args()
    main(args)
|