# text-generation-inference/server/text_generation/cli.py
import os
from pathlib import Path

import typer

from text_generation import server, utils

# Typer application object; the functions decorated with @app.command()
# in this module are registered on it as CLI commands.
app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
):
    """Serve MODEL_NAME by delegating to server.serve.

    When --sharded is given, the RANK, WORLD_SIZE, MASTER_ADDR and
    MASTER_PORT environment variables must all be set; otherwise the
    command fails with a usage error before the server is started.
    """
    if sharded:
        # Validate explicitly instead of with `assert`: assertions are
        # stripped when Python runs with -O, which would silently skip
        # this check. Every missing variable is reported in one message.
        required_env = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
        missing = [name for name in required_env if os.getenv(name) is None]
        if missing:
            raise typer.BadParameter(
                f"{', '.join(missing)} must be set when sharded is True",
                param_hint="'--sharded'",
            )
    # NOTE(review): the default for uds_path is a str even though it is
    # annotated as Path; direct Python callers receive the str default —
    # confirm server.serve accepts both.
    server.serve(model_name, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_name: str,
    extension: str = ".safetensors",
):
    """Download weights for MODEL_NAME by delegating to utils.download_weights.

    EXTENSION (default ".safetensors") is forwarded unchanged.
    """
    utils.download_weights(model_name, extension)


if __name__ == "__main__":
    # Allow running this module directly as a script.
    app()