mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: improve update doc and add command to print router schema
This commit is contained in:
parent
caa44012ad
commit
e4161a185f
13
.github/workflows/autodocs.yaml
vendored
13
.github/workflows/autodocs.yaml
vendored
@ -20,16 +20,9 @@ jobs:
|
||||
echo text-generation-launcher --help
|
||||
python update_doc.py md --check
|
||||
|
||||
- name: Install Protoc
|
||||
uses: arduino/setup-protoc@v1
|
||||
- name: Clean unused files
|
||||
run: |
|
||||
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
|
||||
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
||||
|
||||
- name: Install
|
||||
run: |
|
||||
make install-cpu
|
||||
- name: Install router
|
||||
id: install-router
|
||||
run: cargo install --path router/
|
||||
|
||||
- name: Check that openapi schema is up-to-date
|
||||
run: |
|
||||
|
@ -1,5 +1,6 @@
|
||||
use axum::http::HeaderValue;
|
||||
use clap::Parser;
|
||||
use clap::Subcommand;
|
||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
@ -85,10 +89,15 @@ struct Args {
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
command,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
let print_schema_command = match command {
|
||||
Some(Commands::PrintSchema) => true,
|
||||
None => {
|
||||
// only init logging if we are not running the print schema command
|
||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
print_schema_command,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
@ -1387,10 +1387,10 @@ async fn tokenize(
|
||||
|
||||
/// Prometheus metrics scrape endpoint
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/metrics",
|
||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/metrics",
|
||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||
)]
|
||||
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
prom_handle.render()
|
||||
@ -1430,6 +1430,7 @@ pub async fn run(
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
print_schema_command: bool,
|
||||
) -> Result<(), WebServerError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -1500,6 +1501,12 @@ pub async fn run(
|
||||
struct ApiDoc;
|
||||
|
||||
// Create state
|
||||
if print_schema_command {
|
||||
let api_doc = ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||
|
184
update_doc.py
184
update_doc.py
@ -1,9 +1,7 @@
|
||||
import subprocess
|
||||
import argparse
|
||||
import ast
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
TEMPLATE = """
|
||||
@ -126,155 +124,55 @@ def check_supported_models(check: bool):
|
||||
f.write(final_doc)
|
||||
|
||||
|
||||
def start_server_and_wait():
|
||||
log_file = open("/tmp/server_log.txt", "w")
|
||||
|
||||
process = subprocess.Popen(
|
||||
["text-generation-launcher"],
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
)
|
||||
print("Server is starting...")
|
||||
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get("http://127.0.0.1:3000/health")
|
||||
if response.status_code == 200:
|
||||
print("Server is up and running!")
|
||||
return process, log_file
|
||||
except requests.RequestException:
|
||||
# timeout after 3 minutes (CI can be slow sometimes)
|
||||
if time.time() - start_time > 180:
|
||||
log_file.close()
|
||||
with open("/tmp/server_log.txt", "r") as f:
|
||||
print("Server log:")
|
||||
print(f.read())
|
||||
os.remove("/tmp/server_log.txt")
|
||||
raise TimeoutError("Server didn't start within 60 seconds")
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def stop_server(process, log_file, show=False):
|
||||
process.terminate()
|
||||
process.wait()
|
||||
log_file.close()
|
||||
|
||||
if show:
|
||||
with open("/tmp/server_log.txt", "r") as f:
|
||||
print("Server log:")
|
||||
print(f.read())
|
||||
os.remove("/tmp/server_log.txt")
|
||||
|
||||
|
||||
def get_openapi_json():
|
||||
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
|
||||
# error if not 200
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def update_openapi_json(new_data, filename="docs/openapi.json"):
|
||||
with open(filename, "w") as f:
|
||||
json.dump(new_data, f, indent=2)
|
||||
|
||||
|
||||
def compare_openapi(old_data, new_data):
|
||||
differences = []
|
||||
|
||||
def compare_recursive(old, new, path=""):
|
||||
if isinstance(old, dict) and isinstance(new, dict):
|
||||
for key in set(old.keys()) | set(new.keys()):
|
||||
new_path = f"{path}.{key}" if path else key
|
||||
if key not in old:
|
||||
differences.append(f"Added: {new_path}")
|
||||
elif key not in new:
|
||||
differences.append(f"Removed: {new_path}")
|
||||
else:
|
||||
compare_recursive(old[key], new[key], new_path)
|
||||
elif old != new:
|
||||
differences.append(f"Changed: {path}")
|
||||
|
||||
compare_recursive(old_data, new_data)
|
||||
return differences
|
||||
|
||||
|
||||
def openapi(check: bool):
|
||||
def get_openapi_schema():
|
||||
try:
|
||||
server_process, log_file = start_server_and_wait()
|
||||
|
||||
try:
|
||||
new_openapi_data = get_openapi_json()
|
||||
|
||||
if check:
|
||||
try:
|
||||
with open("docs/openapi.json", "r") as f:
|
||||
old_openapi_data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"docs/openapi.json not found. Run without --check to create it."
|
||||
)
|
||||
return
|
||||
|
||||
differences = compare_openapi(old_openapi_data, new_openapi_data)
|
||||
|
||||
if differences:
|
||||
print("The following differences were found:")
|
||||
for diff in differences:
|
||||
print(diff)
|
||||
print(
|
||||
"Please run the script without --check to update the documentation."
|
||||
)
|
||||
else:
|
||||
print("Documentation is up to date.")
|
||||
else:
|
||||
update_openapi_json(new_openapi_data)
|
||||
print("Documentation updated successfully.")
|
||||
|
||||
finally:
|
||||
stop_server(server_process, log_file)
|
||||
|
||||
except TimeoutError as e:
|
||||
print(f"Error: {e}")
|
||||
raise SystemExit(1)
|
||||
except requests.RequestException as e:
|
||||
print(f"Error communicating with the server: {e}")
|
||||
output = subprocess.check_output(["text-generation-router", "print-schema"])
|
||||
return json.loads(output)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error running text-generation-router print-schema: {e}")
|
||||
raise SystemExit(1)
|
||||
except json.JSONDecodeError:
|
||||
print("Error: Invalid JSON received from the server")
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
print("Error: Invalid JSON received from text-generation-router print-schema")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Update documentation for text-generation-launcher"
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
def check_openapi(check: bool):
|
||||
new_openapi_data = get_openapi_schema()
|
||||
filename = "docs/openapi.json"
|
||||
tmp_filename = "openapi_tmp.json"
|
||||
|
||||
openapi_parser = subparsers.add_parser(
|
||||
"openapi", help="Update OpenAPI documentation"
|
||||
)
|
||||
openapi_parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Check if the OpenAPI documentation needs updating",
|
||||
)
|
||||
with open(tmp_filename, "w") as f:
|
||||
json.dump(new_openapi_data, f, indent=2)
|
||||
|
||||
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
|
||||
md_parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Check if the launcher documentation needs updating",
|
||||
)
|
||||
if check:
|
||||
diff = subprocess.run(
|
||||
["diff", tmp_filename, filename], capture_output=True
|
||||
).stdout.decode()
|
||||
os.remove(tmp_filename)
|
||||
|
||||
if diff:
|
||||
print(diff)
|
||||
raise Exception(
|
||||
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
os.rename(tmp_filename, filename)
|
||||
print("OpenAPI documentation updated.")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "openapi":
|
||||
openapi(args.check)
|
||||
elif args.command == "md":
|
||||
check_cli(args.check)
|
||||
check_supported_models(args.check)
|
||||
check_cli(args.check)
|
||||
check_supported_models(args.check)
|
||||
check_openapi(args.check)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user