fix: prefer improved update_doc and start server and compare

This commit is contained in:
drbh 2024-07-02 15:17:43 +00:00
parent 7b34ba3408
commit 36bd48c293
4 changed files with 146 additions and 68 deletions

View File

@ -16,40 +16,3 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
- repo: local
hooks:
- id: check-openapi-update
name: check openapi spec update
entry: python
language: system
pass_filenames: false
always_run: true
args:
- -c
- |
import os
import sys
import subprocess
def get_changed_files():
result = subprocess.run(['git', 'diff', '--name-only'], capture_output=True, text=True)
return result.stdout.splitlines()
changed_files = get_changed_files()
router_files = [f for f in changed_files if f.startswith('router/')]
if not router_files:
print("No router files changed. Skipping OpenAPI spec check.")
sys.exit(0)
openapi_file = 'docs/openapi.json'
if not os.path.exists(openapi_file):
print(f"Error: {openapi_file} does not exist.")
sys.exit(1)
if openapi_file not in changed_files:
print(f"Error: Router files were updated, but {openapi_file} was not updated.")
sys.exit(1)
else:
print(f"{openapi_file} has been updated along with router changes.")
sys.exit(0)

View File

@ -83,8 +83,6 @@ struct Args {
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t = false)]
update_openapi_schema: bool,
}
#[tokio::main]
@ -121,7 +119,6 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
update_openapi_schema,
} = args;
// Launch Tokio runtime
@ -391,7 +388,6 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
update_openapi_schema,
)
.await?;
Ok(())

View File

@ -1430,7 +1430,6 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
update_openapi_schema: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
@ -1500,25 +1499,6 @@ pub async fn run(
)]
struct ApiDoc;
if update_openapi_schema {
use std::io::Write;
let cargo_workspace =
std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
println!("workspace {}", cargo_workspace);
let output_file = format!("{}/../docs/openapi.json", cargo_workspace);
println!("output file {}", output_file);
let openapi = ApiDoc::openapi();
let mut file = std::fs::File::create(output_file).expect("Unable to create file");
file.write_all(
openapi
.to_pretty_json()
.expect("Unable to serialize OpenAPI")
.as_bytes(),
)
.expect("Unable to write data");
}
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
Arc<dyn Scheduler + Send + Sync>,

View File

@ -1,6 +1,10 @@
import subprocess
import argparse
import ast
import requests
import json
import time
import os
TEMPLATE = """
# Supported Models and Hardware
@ -122,15 +126,150 @@ def check_supported_models(check: bool):
f.write(final_doc)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
def start_server_and_wait():
log_file = open("/tmp/server_log.txt", "w")
args = parser.parse_args()
process = subprocess.Popen(
["text-generation-launcher"],
stdout=log_file,
stderr=subprocess.STDOUT,
universal_newlines=True,
)
print("Server is starting...")
check_cli(args.check)
check_supported_models(args.check)
start_time = time.time()
while True:
try:
response = requests.get("http://127.0.0.1:3000/health")
if response.status_code == 200:
print("Server is up and running!")
return process, log_file
except requests.RequestException:
if time.time() - start_time > 60:
log_file.close()
with open("server_log.txt", "r") as f:
print("Server log:")
print(f.read())
os.remove("server_log.txt")
raise TimeoutError("Server didn't start within 60 seconds")
time.sleep(1)
def stop_server(process, log_file, show=False):
process.terminate()
process.wait()
log_file.close()
if show:
with open("/tmp/server_log.txt", "r") as f:
print("Server log:")
print(f.read())
os.remove("/tmp/server_log.txt")
def get_openapi_json():
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
# error if not 200
response.raise_for_status()
return response.json()
def update_openapi_json(new_data, filename="docs/openapi.json"):
with open(filename, "w") as f:
json.dump(new_data, f, indent=2)
def compare_openapi(old_data, new_data):
differences = []
def compare_recursive(old, new, path=""):
if isinstance(old, dict) and isinstance(new, dict):
for key in set(old.keys()) | set(new.keys()):
new_path = f"{path}.{key}" if path else key
if key not in old:
differences.append(f"Added: {new_path}")
elif key not in new:
differences.append(f"Removed: {new_path}")
else:
compare_recursive(old[key], new[key], new_path)
elif old != new:
differences.append(f"Changed: {path}")
compare_recursive(old_data, new_data)
return differences
def openapi(check: bool):
try:
server_process, log_file = start_server_and_wait()
try:
new_openapi_data = get_openapi_json()
if check:
try:
with open("docs/openapi.json", "r") as f:
old_openapi_data = json.load(f)
except FileNotFoundError:
print(
"docs/openapi.json not found. Run without --check to create it."
)
return
differences = compare_openapi(old_openapi_data, new_openapi_data)
if differences:
print("The following differences were found:")
for diff in differences:
print(diff)
print(
"Please run the script without --check to update the documentation."
)
else:
print("Documentation is up to date.")
else:
update_openapi_json(new_openapi_data)
print("Documentation updated successfully.")
finally:
stop_server(server_process, log_file)
except TimeoutError as e:
print(f"Error: {e}")
except requests.RequestException as e:
print(f"Error communicating with the server: {e}")
except json.JSONDecodeError:
print("Error: Invalid JSON received from the server")
except Exception as e:
print(f"An unexpected error occurred: {e}")
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(
description="Update documentation for text-generation-launcher"
)
subparsers = parser.add_subparsers(dest="command", required=True)
openapi_parser = subparsers.add_parser(
"openapi", help="Update OpenAPI documentation"
)
openapi_parser.add_argument(
"--check",
action="store_true",
help="Check if the OpenAPI documentation needs updating",
)
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
md_parser.add_argument(
"--check",
action="store_true",
help="Check if the launcher documentation needs updating",
)
args = parser.parse_args()
if args.command == "openapi":
openapi(args)
elif args.command == "md":
check_cli(args.check)
check_supported_models(args.check)