mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: prefer improved update_doc and start server and compare
This commit is contained in:
parent
7b34ba3408
commit
36bd48c293
@ -16,40 +16,3 @@ repos:
|
||||
- id: fmt
|
||||
- id: cargo-check
|
||||
- id: clippy
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-openapi-update
|
||||
name: check openapi spec update
|
||||
entry: python
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def get_changed_files():
|
||||
result = subprocess.run(['git', 'diff', '--name-only'], capture_output=True, text=True)
|
||||
return result.stdout.splitlines()
|
||||
|
||||
changed_files = get_changed_files()
|
||||
router_files = [f for f in changed_files if f.startswith('router/')]
|
||||
|
||||
if not router_files:
|
||||
print("No router files changed. Skipping OpenAPI spec check.")
|
||||
sys.exit(0)
|
||||
|
||||
openapi_file = 'docs/openapi.json'
|
||||
if not os.path.exists(openapi_file):
|
||||
print(f"Error: {openapi_file} does not exist.")
|
||||
sys.exit(1)
|
||||
|
||||
if openapi_file not in changed_files:
|
||||
print(f"Error: Router files were updated, but {openapi_file} was not updated.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"{openapi_file} has been updated along with router changes.")
|
||||
sys.exit(0)
|
||||
|
@ -83,8 +83,6 @@ struct Args {
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
update_openapi_schema: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -121,7 +119,6 @@ async fn main() -> Result<(), RouterError> {
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
update_openapi_schema,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
@ -391,7 +388,6 @@ async fn main() -> Result<(), RouterError> {
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
update_openapi_schema,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
@ -1430,7 +1430,6 @@ pub async fn run(
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
update_openapi_schema: bool,
|
||||
) -> Result<(), WebServerError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -1500,25 +1499,6 @@ pub async fn run(
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
if update_openapi_schema {
|
||||
use std::io::Write;
|
||||
let cargo_workspace =
|
||||
std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
||||
println!("workspace {}", cargo_workspace);
|
||||
let output_file = format!("{}/../docs/openapi.json", cargo_workspace);
|
||||
println!("output file {}", output_file);
|
||||
|
||||
let openapi = ApiDoc::openapi();
|
||||
let mut file = std::fs::File::create(output_file).expect("Unable to create file");
|
||||
file.write_all(
|
||||
openapi
|
||||
.to_pretty_json()
|
||||
.expect("Unable to serialize OpenAPI")
|
||||
.as_bytes(),
|
||||
)
|
||||
.expect("Unable to write data");
|
||||
}
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||
Arc<dyn Scheduler + Send + Sync>,
|
||||
|
153
update_doc.py
153
update_doc.py
@ -1,6 +1,10 @@
|
||||
import subprocess
|
||||
import argparse
|
||||
import ast
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
TEMPLATE = """
|
||||
# Supported Models and Hardware
|
||||
@ -122,15 +126,150 @@ def check_supported_models(check: bool):
|
||||
f.write(final_doc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
def start_server_and_wait():
|
||||
log_file = open("/tmp/server_log.txt", "w")
|
||||
|
||||
args = parser.parse_args()
|
||||
process = subprocess.Popen(
|
||||
["text-generation-launcher"],
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
)
|
||||
print("Server is starting...")
|
||||
|
||||
check_cli(args.check)
|
||||
check_supported_models(args.check)
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get("http://127.0.0.1:3000/health")
|
||||
if response.status_code == 200:
|
||||
print("Server is up and running!")
|
||||
return process, log_file
|
||||
except requests.RequestException:
|
||||
if time.time() - start_time > 60:
|
||||
log_file.close()
|
||||
with open("server_log.txt", "r") as f:
|
||||
print("Server log:")
|
||||
print(f.read())
|
||||
os.remove("server_log.txt")
|
||||
raise TimeoutError("Server didn't start within 60 seconds")
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def stop_server(process, log_file, show=False):
|
||||
process.terminate()
|
||||
process.wait()
|
||||
log_file.close()
|
||||
|
||||
if show:
|
||||
with open("/tmp/server_log.txt", "r") as f:
|
||||
print("Server log:")
|
||||
print(f.read())
|
||||
os.remove("/tmp/server_log.txt")
|
||||
|
||||
|
||||
def get_openapi_json():
|
||||
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
|
||||
# error if not 200
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def update_openapi_json(new_data, filename="docs/openapi.json"):
|
||||
with open(filename, "w") as f:
|
||||
json.dump(new_data, f, indent=2)
|
||||
|
||||
|
||||
def compare_openapi(old_data, new_data):
|
||||
differences = []
|
||||
|
||||
def compare_recursive(old, new, path=""):
|
||||
if isinstance(old, dict) and isinstance(new, dict):
|
||||
for key in set(old.keys()) | set(new.keys()):
|
||||
new_path = f"{path}.{key}" if path else key
|
||||
if key not in old:
|
||||
differences.append(f"Added: {new_path}")
|
||||
elif key not in new:
|
||||
differences.append(f"Removed: {new_path}")
|
||||
else:
|
||||
compare_recursive(old[key], new[key], new_path)
|
||||
elif old != new:
|
||||
differences.append(f"Changed: {path}")
|
||||
|
||||
compare_recursive(old_data, new_data)
|
||||
return differences
|
||||
|
||||
|
||||
def openapi(check: bool):
|
||||
try:
|
||||
server_process, log_file = start_server_and_wait()
|
||||
|
||||
try:
|
||||
new_openapi_data = get_openapi_json()
|
||||
|
||||
if check:
|
||||
try:
|
||||
with open("docs/openapi.json", "r") as f:
|
||||
old_openapi_data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"docs/openapi.json not found. Run without --check to create it."
|
||||
)
|
||||
return
|
||||
|
||||
differences = compare_openapi(old_openapi_data, new_openapi_data)
|
||||
|
||||
if differences:
|
||||
print("The following differences were found:")
|
||||
for diff in differences:
|
||||
print(diff)
|
||||
print(
|
||||
"Please run the script without --check to update the documentation."
|
||||
)
|
||||
else:
|
||||
print("Documentation is up to date.")
|
||||
else:
|
||||
update_openapi_json(new_openapi_data)
|
||||
print("Documentation updated successfully.")
|
||||
|
||||
finally:
|
||||
stop_server(server_process, log_file)
|
||||
|
||||
except TimeoutError as e:
|
||||
print(f"Error: {e}")
|
||||
except requests.RequestException as e:
|
||||
print(f"Error communicating with the server: {e}")
|
||||
except json.JSONDecodeError:
|
||||
print("Error: Invalid JSON received from the server")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Update documentation for text-generation-launcher"
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
openapi_parser = subparsers.add_parser(
|
||||
"openapi", help="Update OpenAPI documentation"
|
||||
)
|
||||
openapi_parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Check if the OpenAPI documentation needs updating",
|
||||
)
|
||||
|
||||
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
|
||||
md_parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Check if the launcher documentation needs updating",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "openapi":
|
||||
openapi(args)
|
||||
elif args.command == "md":
|
||||
check_cli(args.check)
|
||||
check_supported_models(args.check)
|
||||
|
Loading…
Reference in New Issue
Block a user