mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: prefer improved update_doc and start server and compare
This commit is contained in:
parent
7b34ba3408
commit
36bd48c293
@ -16,40 +16,3 @@ repos:
|
|||||||
- id: fmt
|
- id: fmt
|
||||||
- id: cargo-check
|
- id: cargo-check
|
||||||
- id: clippy
|
- id: clippy
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: check-openapi-update
|
|
||||||
name: check openapi spec update
|
|
||||||
entry: python
|
|
||||||
language: system
|
|
||||||
pass_filenames: false
|
|
||||||
always_run: true
|
|
||||||
args:
|
|
||||||
- -c
|
|
||||||
- |
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
def get_changed_files():
|
|
||||||
result = subprocess.run(['git', 'diff', '--name-only'], capture_output=True, text=True)
|
|
||||||
return result.stdout.splitlines()
|
|
||||||
|
|
||||||
changed_files = get_changed_files()
|
|
||||||
router_files = [f for f in changed_files if f.startswith('router/')]
|
|
||||||
|
|
||||||
if not router_files:
|
|
||||||
print("No router files changed. Skipping OpenAPI spec check.")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
openapi_file = 'docs/openapi.json'
|
|
||||||
if not os.path.exists(openapi_file):
|
|
||||||
print(f"Error: {openapi_file} does not exist.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if openapi_file not in changed_files:
|
|
||||||
print(f"Error: Router files were updated, but {openapi_file} was not updated.")
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
print(f"{openapi_file} has been updated along with router changes.")
|
|
||||||
sys.exit(0)
|
|
||||||
|
@ -83,8 +83,6 @@ struct Args {
|
|||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
#[clap(default_value = "4", long, env)]
|
#[clap(default_value = "4", long, env)]
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
update_openapi_schema: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -121,7 +119,6 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
update_openapi_schema,
|
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
@ -391,7 +388,6 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
update_openapi_schema,
|
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1430,7 +1430,6 @@ pub async fn run(
|
|||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
grammar_support: bool,
|
grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
update_openapi_schema: bool,
|
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
@ -1500,25 +1499,6 @@ pub async fn run(
|
|||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
if update_openapi_schema {
|
|
||||||
use std::io::Write;
|
|
||||||
let cargo_workspace =
|
|
||||||
std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
|
||||||
println!("workspace {}", cargo_workspace);
|
|
||||||
let output_file = format!("{}/../docs/openapi.json", cargo_workspace);
|
|
||||||
println!("output file {}", output_file);
|
|
||||||
|
|
||||||
let openapi = ApiDoc::openapi();
|
|
||||||
let mut file = std::fs::File::create(output_file).expect("Unable to create file");
|
|
||||||
file.write_all(
|
|
||||||
openapi
|
|
||||||
.to_pretty_json()
|
|
||||||
.expect("Unable to serialize OpenAPI")
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.expect("Unable to write data");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open connection, get model info and warmup
|
// Open connection, get model info and warmup
|
||||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||||
Arc<dyn Scheduler + Send + Sync>,
|
Arc<dyn Scheduler + Send + Sync>,
|
||||||
|
153
update_doc.py
153
update_doc.py
@ -1,6 +1,10 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
TEMPLATE = """
|
TEMPLATE = """
|
||||||
# Supported Models and Hardware
|
# Supported Models and Hardware
|
||||||
@ -122,15 +126,150 @@ def check_supported_models(check: bool):
|
|||||||
f.write(final_doc)
|
f.write(final_doc)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def start_server_and_wait():
|
||||||
parser = argparse.ArgumentParser()
|
log_file = open("/tmp/server_log.txt", "w")
|
||||||
parser.add_argument("--check", action="store_true")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
process = subprocess.Popen(
|
||||||
|
["text-generation-launcher"],
|
||||||
|
stdout=log_file,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
universal_newlines=True,
|
||||||
|
)
|
||||||
|
print("Server is starting...")
|
||||||
|
|
||||||
check_cli(args.check)
|
start_time = time.time()
|
||||||
check_supported_models(args.check)
|
while True:
|
||||||
|
try:
|
||||||
|
response = requests.get("http://127.0.0.1:3000/health")
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Server is up and running!")
|
||||||
|
return process, log_file
|
||||||
|
except requests.RequestException:
|
||||||
|
if time.time() - start_time > 60:
|
||||||
|
log_file.close()
|
||||||
|
with open("server_log.txt", "r") as f:
|
||||||
|
print("Server log:")
|
||||||
|
print(f.read())
|
||||||
|
os.remove("server_log.txt")
|
||||||
|
raise TimeoutError("Server didn't start within 60 seconds")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_server(process, log_file, show=False):
|
||||||
|
process.terminate()
|
||||||
|
process.wait()
|
||||||
|
log_file.close()
|
||||||
|
|
||||||
|
if show:
|
||||||
|
with open("/tmp/server_log.txt", "r") as f:
|
||||||
|
print("Server log:")
|
||||||
|
print(f.read())
|
||||||
|
os.remove("/tmp/server_log.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_json():
|
||||||
|
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
|
||||||
|
# error if not 200
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def update_openapi_json(new_data, filename="docs/openapi.json"):
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(new_data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_openapi(old_data, new_data):
|
||||||
|
differences = []
|
||||||
|
|
||||||
|
def compare_recursive(old, new, path=""):
|
||||||
|
if isinstance(old, dict) and isinstance(new, dict):
|
||||||
|
for key in set(old.keys()) | set(new.keys()):
|
||||||
|
new_path = f"{path}.{key}" if path else key
|
||||||
|
if key not in old:
|
||||||
|
differences.append(f"Added: {new_path}")
|
||||||
|
elif key not in new:
|
||||||
|
differences.append(f"Removed: {new_path}")
|
||||||
|
else:
|
||||||
|
compare_recursive(old[key], new[key], new_path)
|
||||||
|
elif old != new:
|
||||||
|
differences.append(f"Changed: {path}")
|
||||||
|
|
||||||
|
compare_recursive(old_data, new_data)
|
||||||
|
return differences
|
||||||
|
|
||||||
|
|
||||||
|
def openapi(check: bool):
|
||||||
|
try:
|
||||||
|
server_process, log_file = start_server_and_wait()
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_openapi_data = get_openapi_json()
|
||||||
|
|
||||||
|
if check:
|
||||||
|
try:
|
||||||
|
with open("docs/openapi.json", "r") as f:
|
||||||
|
old_openapi_data = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(
|
||||||
|
"docs/openapi.json not found. Run without --check to create it."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
differences = compare_openapi(old_openapi_data, new_openapi_data)
|
||||||
|
|
||||||
|
if differences:
|
||||||
|
print("The following differences were found:")
|
||||||
|
for diff in differences:
|
||||||
|
print(diff)
|
||||||
|
print(
|
||||||
|
"Please run the script without --check to update the documentation."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Documentation is up to date.")
|
||||||
|
else:
|
||||||
|
update_openapi_json(new_openapi_data)
|
||||||
|
print("Documentation updated successfully.")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
stop_server(server_process, log_file)
|
||||||
|
|
||||||
|
except TimeoutError as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error communicating with the server: {e}")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Error: Invalid JSON received from the server")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An unexpected error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Update documentation for text-generation-launcher"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
openapi_parser = subparsers.add_parser(
|
||||||
|
"openapi", help="Update OpenAPI documentation"
|
||||||
|
)
|
||||||
|
openapi_parser.add_argument(
|
||||||
|
"--check",
|
||||||
|
action="store_true",
|
||||||
|
help="Check if the OpenAPI documentation needs updating",
|
||||||
|
)
|
||||||
|
|
||||||
|
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
|
||||||
|
md_parser.add_argument(
|
||||||
|
"--check",
|
||||||
|
action="store_true",
|
||||||
|
help="Check if the launcher documentation needs updating",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command == "openapi":
|
||||||
|
openapi(args)
|
||||||
|
elif args.command == "md":
|
||||||
|
check_cli(args.check)
|
||||||
|
check_supported_models(args.check)
|
||||||
|
Loading…
Reference in New Issue
Block a user