mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
116 lines · 3.7 KiB · Python
import sys
from typing import Optional

import typer
from loguru import logger

app = typer.Typer()
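# Each function decorated with @app.command() below becomes a subcommand of
# this Typer app; text-generation-launcher calls them to download weights and
# to start the gRPC server.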


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    trust_remote_code: Optional[bool] = None,
    uds_path: str = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
"""This is the main entry-point for the server CLI.
|
|
|
|
Args:
|
|
model_id (`str`):
|
|
The *model_id* of a model on the HuggingFace hub or the path to a local model.
|
|
revision (`Optional[str]`, defaults to `None`):
|
|
The revision of the model on the HuggingFace hub.
|
|
sharded (`bool`):
|
|
Whether the model must be sharded or not. Kept for compatibility with the
|
|
text-generation-launcher, but must be set to False.
|
|
trust-remote-code (`bool`):
|
|
Kept for compatibility with text-generation-launcher. Ignored.
|
|
uds_path (`Union[Path, str]`):
|
|
The local path on which the server will expose its google RPC services.
|
|
logger_level (`str`):
|
|
The server logger level. Defaults to *INFO*.
|
|
json_output (`bool`):
|
|
Use JSON format for log serialization.
|
|
otlp_endpoint (`Optional[str]`, defaults to `None`):
|
|
The Open Telemetry endpoint to use.
|
|
otlp_service_name (`Optional[str]`, defaults to `None`):
|
|
The name to use when pushing data to the Open Telemetry endpoint.
|
|
max_input_tokens (`Optional[int]`, defaults to `None`):
|
|
The maximum number of input tokens each request should contain.
|
|
"""
    if sharded:
        raise ValueError("Sharding is not supported.")
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if trust_remote_code is not None:
        logger.warning(
            "'trust_remote_code' argument is not supported and will be ignored."
        )

    # Import here after the logger is added to log potential import exceptions
    from .server import serve

    serve(model_id, revision, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    auto_convert: Optional[bool] = None,
    extension: Optional[str] = None,
    trust_remote_code: Optional[bool] = None,
    merge_lora: Optional[bool] = None,
):
"""Download the model weights.
|
|
|
|
This command will be called by text-generation-launcher before serving the model.
|
|
"""
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if extension is not None:
        logger.warning("'extension' argument is not supported and will be ignored.")
    if trust_remote_code is not None:
        logger.warning(
            "'trust_remote_code' argument is not supported and will be ignored."
        )
    if auto_convert is not None:
        logger.warning("'auto_convert' argument is not supported and will be ignored.")
    if merge_lora is not None:
        logger.warning("'merge_lora' argument is not supported and will be ignored.")

    # Import here after the logger is added to log potential import exceptions
    from .model import fetch_model

    fetch_model(model_id, revision)