feat: improve update doc and add command to print router schema

This commit is contained in:
drbh 2024-07-02 18:01:29 +00:00
parent caa44012ad
commit e4161a185f
4 changed files with 75 additions and 160 deletions

View File

@@ -20,16 +20,9 @@ jobs:
echo text-generation-launcher --help
python update_doc.py md --check
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Clean unused files
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install
run: |
make install-cpu
- name: Install router
id: install-router
run: cargo install --path router/
- name: Check that openapi schema is up-to-date
run: |

View File

@@ -1,5 +1,6 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
@@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
@@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
@@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
command,
} = args;
// Launch Tokio runtime
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
@@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await?;
Ok(())

View File

@@ -1387,10 +1387,10 @@ async fn tokenize(
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
@@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
@@ -1500,6 +1501,12 @@ pub async fn run(
struct ApiDoc;
// Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (

View File

@@ -1,9 +1,7 @@
import subprocess
import argparse
import ast
import requests
import json
import time
import os
TEMPLATE = """
@@ -126,155 +124,55 @@ def check_supported_models(check: bool):
f.write(final_doc)
def start_server_and_wait():
log_file = open("/tmp/server_log.txt", "w")
process = subprocess.Popen(
["text-generation-launcher"],
stdout=log_file,
stderr=subprocess.STDOUT,
universal_newlines=True,
)
print("Server is starting...")
start_time = time.time()
while True:
def get_openapi_schema():
try:
response = requests.get("http://127.0.0.1:3000/health")
if response.status_code == 200:
print("Server is up and running!")
return process, log_file
except requests.RequestException:
# timeout after 3 minutes (CI can be slow sometimes)
if time.time() - start_time > 180:
log_file.close()
with open("/tmp/server_log.txt", "r") as f:
print("Server log:")
print(f.read())
os.remove("/tmp/server_log.txt")
raise TimeoutError("Server didn't start within 60 seconds")
time.sleep(1)
def stop_server(process, log_file, show=False):
process.terminate()
process.wait()
log_file.close()
if show:
with open("/tmp/server_log.txt", "r") as f:
print("Server log:")
print(f.read())
os.remove("/tmp/server_log.txt")
def get_openapi_json():
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
# error if not 200
response.raise_for_status()
return response.json()
def update_openapi_json(new_data, filename="docs/openapi.json"):
with open(filename, "w") as f:
json.dump(new_data, f, indent=2)
def compare_openapi(old_data, new_data):
differences = []
def compare_recursive(old, new, path=""):
if isinstance(old, dict) and isinstance(new, dict):
for key in set(old.keys()) | set(new.keys()):
new_path = f"{path}.{key}" if path else key
if key not in old:
differences.append(f"Added: {new_path}")
elif key not in new:
differences.append(f"Removed: {new_path}")
else:
compare_recursive(old[key], new[key], new_path)
elif old != new:
differences.append(f"Changed: {path}")
compare_recursive(old_data, new_data)
return differences
def openapi(check: bool):
try:
server_process, log_file = start_server_and_wait()
try:
new_openapi_data = get_openapi_json()
if check:
try:
with open("docs/openapi.json", "r") as f:
old_openapi_data = json.load(f)
except FileNotFoundError:
print(
"docs/openapi.json not found. Run without --check to create it."
)
return
differences = compare_openapi(old_openapi_data, new_openapi_data)
if differences:
print("The following differences were found:")
for diff in differences:
print(diff)
print(
"Please run the script without --check to update the documentation."
)
else:
print("Documentation is up to date.")
else:
update_openapi_json(new_openapi_data)
print("Documentation updated successfully.")
finally:
stop_server(server_process, log_file)
except TimeoutError as e:
print(f"Error: {e}")
raise SystemExit(1)
except requests.RequestException as e:
print(f"Error communicating with the server: {e}")
output = subprocess.check_output(["text-generation-router", "print-schema"])
return json.loads(output)
except subprocess.CalledProcessError as e:
print(f"Error running text-generation-router print-schema: {e}")
raise SystemExit(1)
except json.JSONDecodeError:
print("Error: Invalid JSON received from the server")
raise SystemExit(1)
except Exception as e:
print(f"An unexpected error occurred: {e}")
print("Error: Invalid JSON received from text-generation-router print-schema")
raise SystemExit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Update documentation for text-generation-launcher"
)
subparsers = parser.add_subparsers(dest="command", required=True)
def check_openapi(check: bool):
new_openapi_data = get_openapi_schema()
filename = "docs/openapi.json"
tmp_filename = "openapi_tmp.json"
openapi_parser = subparsers.add_parser(
"openapi", help="Update OpenAPI documentation"
)
openapi_parser.add_argument(
"--check",
action="store_true",
help="Check if the OpenAPI documentation needs updating",
with open(tmp_filename, "w") as f:
json.dump(new_openapi_data, f, indent=2)
if check:
diff = subprocess.run(
["diff", tmp_filename, filename], capture_output=True
).stdout.decode()
os.remove(tmp_filename)
if diff:
print(diff)
raise Exception(
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
)
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
md_parser.add_argument(
"--check",
action="store_true",
help="Check if the launcher documentation needs updating",
)
return True
else:
os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.")
return True
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
args = parser.parse_args()
if args.command == "openapi":
openapi(args.check)
elif args.command == "md":
check_cli(args.check)
check_supported_models(args.check)
check_openapi(args.check)
if __name__ == "__main__":
main()