From 36bd48c293659311b9652e8589dea53088eea4e7 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 2 Jul 2024 15:17:43 +0000 Subject: [PATCH] fix: prefer improved update_doc and start server and compare --- .pre-commit-config.yaml | 37 ---------- router/src/main.rs | 4 -- router/src/server.rs | 20 ------ update_doc.py | 153 ++++++++++++++++++++++++++++++++++++++-- 4 files changed, 146 insertions(+), 68 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c40250d5..45bc07a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,40 +16,3 @@ repos: - id: fmt - id: cargo-check - id: clippy -- repo: local - hooks: - - id: check-openapi-update - name: check openapi spec update - entry: python - language: system - pass_filenames: false - always_run: true - args: - - -c - - | - import os - import sys - import subprocess - - def get_changed_files(): - result = subprocess.run(['git', 'diff', '--name-only'], capture_output=True, text=True) - return result.stdout.splitlines() - - changed_files = get_changed_files() - router_files = [f for f in changed_files if f.startswith('router/')] - - if not router_files: - print("No router files changed. Skipping OpenAPI spec check.") - sys.exit(0) - - openapi_file = 'docs/openapi.json' - if not os.path.exists(openapi_file): - print(f"Error: {openapi_file} does not exist.") - sys.exit(1) - - if openapi_file not in changed_files: - print(f"Error: Router files were updated, but {openapi_file} was not updated.") - sys.exit(1) - else: - print(f"{openapi_file} has been updated along with router changes.") - sys.exit(0) diff --git a/router/src/main.rs b/router/src/main.rs index 210cfca8..8618f57e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -83,8 +83,6 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, - #[clap(long, env, default_value_t = false)] - update_openapi_schema: bool, } #[tokio::main] @@ -121,7 +119,6 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, - update_openapi_schema, } = args; // Launch Tokio runtime @@ -391,7 +388,6 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, - update_openapi_schema, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 3268d50b..bb952f15 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1430,7 +1430,6 @@ pub async fn run( messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, - update_openapi_schema: bool, ) -> Result<(), WebServerError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1500,25 +1499,6 @@ pub async fn run( )] struct ApiDoc; - if update_openapi_schema { - use std::io::Write; - let cargo_workspace = - std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string()); - println!("workspace {}", cargo_workspace); - let output_file = format!("{}/../docs/openapi.json", cargo_workspace); - println!("output file {}", output_file); - - let openapi = ApiDoc::openapi(); - let mut file = std::fs::File::create(output_file).expect("Unable to create file"); - file.write_all( - openapi - .to_pretty_json() - .expect("Unable to serialize OpenAPI") - .as_bytes(), - ) - .expect("Unable to write data"); - } - // Open connection, get model info and warmup let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( Arc, diff --git a/update_doc.py b/update_doc.py index 5da81c72..d3132db7 100644 --- a/update_doc.py +++ b/update_doc.py @@ -1,6 +1,10 @@ import subprocess import argparse import ast +import requests +import json +import time +import os TEMPLATE = """ # Supported Models and Hardware @@ -122,15 +126,150 @@ def check_supported_models(check: bool): f.write(final_doc) -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--check", action="store_true") +def start_server_and_wait(): + log_file = open("/tmp/server_log.txt", "w") - args = parser.parse_args() + process = subprocess.Popen( + ["text-generation-launcher"], + stdout=log_file, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + print("Server is starting...") - check_cli(args.check) - check_supported_models(args.check) + start_time = time.time() + while True: + try: + response = requests.get("http://127.0.0.1:3000/health") + if response.status_code == 200: + print("Server is up and running!") + return process, log_file + except requests.RequestException: + if time.time() - start_time > 60: + log_file.close() + with open("server_log.txt", "r") as f: + print("Server log:") + print(f.read()) + os.remove("server_log.txt") + raise TimeoutError("Server didn't start within 60 seconds") + time.sleep(1) + + +def stop_server(process, log_file, show=False): + process.terminate() + process.wait() + log_file.close() + + if show: + with open("/tmp/server_log.txt", "r") as f: + print("Server log:") + print(f.read()) + os.remove("/tmp/server_log.txt") + + +def get_openapi_json(): + response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json") + # error if not 200 + response.raise_for_status() + return response.json() + + +def update_openapi_json(new_data, filename="docs/openapi.json"): + with open(filename, "w") as f: + json.dump(new_data, f, indent=2) + + +def compare_openapi(old_data, new_data): + differences = [] + + def compare_recursive(old, new, path=""): + if isinstance(old, dict) and isinstance(new, dict): + for key in set(old.keys()) | set(new.keys()): + new_path = f"{path}.{key}" if path else key + if key not in old: + differences.append(f"Added: {new_path}") + elif key not in new: + differences.append(f"Removed: {new_path}") + else: + compare_recursive(old[key], new[key], new_path) + elif old != new: + differences.append(f"Changed: {path}") + + compare_recursive(old_data, new_data) + return differences + + +def openapi(check: bool): + try: + server_process, log_file = start_server_and_wait() + + try: + new_openapi_data = get_openapi_json() + + if check: + try: + with open("docs/openapi.json", "r") as f: + old_openapi_data = json.load(f) + except FileNotFoundError: + print( + "docs/openapi.json not found. Run without --check to create it." + ) + return + + differences = compare_openapi(old_openapi_data, new_openapi_data) + + if differences: + print("The following differences were found:") + for diff in differences: + print(diff) + print( + "Please run the script without --check to update the documentation." + ) + else: + print("Documentation is up to date.") + else: + update_openapi_json(new_openapi_data) + print("Documentation updated successfully.") + + finally: + stop_server(server_process, log_file) + + except TimeoutError as e: + print(f"Error: {e}") + except requests.RequestException as e: + print(f"Error communicating with the server: {e}") + except json.JSONDecodeError: + print("Error: Invalid JSON received from the server") + except Exception as e: + print(f"An unexpected error occurred: {e}") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="Update documentation for text-generation-launcher" + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + openapi_parser = subparsers.add_parser( + "openapi", help="Update OpenAPI documentation" + ) + openapi_parser.add_argument( + "--check", + action="store_true", + help="Check if the OpenAPI documentation needs updating", + ) + + md_parser = subparsers.add_parser("md", help="Update launcher and supported models") + md_parser.add_argument( + "--check", + action="store_true", + help="Check if the launcher documentation needs updating", + ) + + args = parser.parse_args() + + if args.command == "openapi": + openapi(args) + elif args.command == "md": + check_cli(args.check) + check_supported_models(args.check)