mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: improve update doc and add command to print router schema
This commit is contained in:
parent
caa44012ad
commit
e4161a185f
13
.github/workflows/autodocs.yaml
vendored
13
.github/workflows/autodocs.yaml
vendored
@ -20,16 +20,9 @@ jobs:
|
|||||||
echo text-generation-launcher --help
|
echo text-generation-launcher --help
|
||||||
python update_doc.py md --check
|
python update_doc.py md --check
|
||||||
|
|
||||||
- name: Install Protoc
|
- name: Install router
|
||||||
uses: arduino/setup-protoc@v1
|
id: install-router
|
||||||
- name: Clean unused files
|
run: cargo install --path router/
|
||||||
run: |
|
|
||||||
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
|
|
||||||
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
|
||||||
|
|
||||||
- name: Install
|
|
||||||
run: |
|
|
||||||
make install-cpu
|
|
||||||
|
|
||||||
- name: Check that openapi schema is up-to-date
|
- name: Check that openapi schema is up-to-date
|
||||||
run: |
|
run: |
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use clap::Subcommand;
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||||
use hf_hub::{Cache, Repo, RepoType};
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
|||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[clap(author, version, about, long_about = None)]
|
#[clap(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
|
||||||
#[clap(default_value = "128", long, env)]
|
#[clap(default_value = "128", long, env)]
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
@ -85,10 +89,15 @@ struct Args {
|
|||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
PrintSchema,
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), RouterError> {
|
async fn main() -> Result<(), RouterError> {
|
||||||
// Get args
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
let Args {
|
let Args {
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
|
command,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
let print_schema_command = match command {
|
||||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
Some(Commands::PrintSchema) => true,
|
||||||
|
None => {
|
||||||
|
// only init logging if we are not running the print schema command
|
||||||
|
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_tokens >= max_total_tokens {
|
if max_input_tokens >= max_total_tokens {
|
||||||
@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
|
print_schema_command,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1387,10 +1387,10 @@ async fn tokenize(
|
|||||||
|
|
||||||
/// Prometheus metrics scrape endpoint
|
/// Prometheus metrics scrape endpoint
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
get,
|
get,
|
||||||
tag = "Text Generation Inference",
|
tag = "Text Generation Inference",
|
||||||
path = "/metrics",
|
path = "/metrics",
|
||||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||||
)]
|
)]
|
||||||
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||||
prom_handle.render()
|
prom_handle.render()
|
||||||
@ -1430,6 +1430,7 @@ pub async fn run(
|
|||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
grammar_support: bool,
|
grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
|
print_schema_command: bool,
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
@ -1500,6 +1501,12 @@ pub async fn run(
|
|||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
|
if print_schema_command {
|
||||||
|
let api_doc = ApiDoc::openapi();
|
||||||
|
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||||
|
println!("{}", api_doc);
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
// Open connection, get model info and warmup
|
// Open connection, get model info and warmup
|
||||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||||
|
184
update_doc.py
184
update_doc.py
@ -1,9 +1,7 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import requests
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
TEMPLATE = """
|
TEMPLATE = """
|
||||||
@ -126,155 +124,55 @@ def check_supported_models(check: bool):
|
|||||||
f.write(final_doc)
|
f.write(final_doc)
|
||||||
|
|
||||||
|
|
||||||
def start_server_and_wait():
|
def get_openapi_schema():
|
||||||
log_file = open("/tmp/server_log.txt", "w")
|
|
||||||
|
|
||||||
process = subprocess.Popen(
|
|
||||||
["text-generation-launcher"],
|
|
||||||
stdout=log_file,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
universal_newlines=True,
|
|
||||||
)
|
|
||||||
print("Server is starting...")
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
response = requests.get("http://127.0.0.1:3000/health")
|
|
||||||
if response.status_code == 200:
|
|
||||||
print("Server is up and running!")
|
|
||||||
return process, log_file
|
|
||||||
except requests.RequestException:
|
|
||||||
# timeout after 3 minutes (CI can be slow sometimes)
|
|
||||||
if time.time() - start_time > 180:
|
|
||||||
log_file.close()
|
|
||||||
with open("/tmp/server_log.txt", "r") as f:
|
|
||||||
print("Server log:")
|
|
||||||
print(f.read())
|
|
||||||
os.remove("/tmp/server_log.txt")
|
|
||||||
raise TimeoutError("Server didn't start within 60 seconds")
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
|
|
||||||
def stop_server(process, log_file, show=False):
|
|
||||||
process.terminate()
|
|
||||||
process.wait()
|
|
||||||
log_file.close()
|
|
||||||
|
|
||||||
if show:
|
|
||||||
with open("/tmp/server_log.txt", "r") as f:
|
|
||||||
print("Server log:")
|
|
||||||
print(f.read())
|
|
||||||
os.remove("/tmp/server_log.txt")
|
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_json():
|
|
||||||
response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json")
|
|
||||||
# error if not 200
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
def update_openapi_json(new_data, filename="docs/openapi.json"):
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
json.dump(new_data, f, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def compare_openapi(old_data, new_data):
|
|
||||||
differences = []
|
|
||||||
|
|
||||||
def compare_recursive(old, new, path=""):
|
|
||||||
if isinstance(old, dict) and isinstance(new, dict):
|
|
||||||
for key in set(old.keys()) | set(new.keys()):
|
|
||||||
new_path = f"{path}.{key}" if path else key
|
|
||||||
if key not in old:
|
|
||||||
differences.append(f"Added: {new_path}")
|
|
||||||
elif key not in new:
|
|
||||||
differences.append(f"Removed: {new_path}")
|
|
||||||
else:
|
|
||||||
compare_recursive(old[key], new[key], new_path)
|
|
||||||
elif old != new:
|
|
||||||
differences.append(f"Changed: {path}")
|
|
||||||
|
|
||||||
compare_recursive(old_data, new_data)
|
|
||||||
return differences
|
|
||||||
|
|
||||||
|
|
||||||
def openapi(check: bool):
|
|
||||||
try:
|
try:
|
||||||
server_process, log_file = start_server_and_wait()
|
output = subprocess.check_output(["text-generation-router", "print-schema"])
|
||||||
|
return json.loads(output)
|
||||||
try:
|
except subprocess.CalledProcessError as e:
|
||||||
new_openapi_data = get_openapi_json()
|
print(f"Error running text-generation-router print-schema: {e}")
|
||||||
|
|
||||||
if check:
|
|
||||||
try:
|
|
||||||
with open("docs/openapi.json", "r") as f:
|
|
||||||
old_openapi_data = json.load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(
|
|
||||||
"docs/openapi.json not found. Run without --check to create it."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
differences = compare_openapi(old_openapi_data, new_openapi_data)
|
|
||||||
|
|
||||||
if differences:
|
|
||||||
print("The following differences were found:")
|
|
||||||
for diff in differences:
|
|
||||||
print(diff)
|
|
||||||
print(
|
|
||||||
"Please run the script without --check to update the documentation."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print("Documentation is up to date.")
|
|
||||||
else:
|
|
||||||
update_openapi_json(new_openapi_data)
|
|
||||||
print("Documentation updated successfully.")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
stop_server(server_process, log_file)
|
|
||||||
|
|
||||||
except TimeoutError as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
raise SystemExit(1)
|
|
||||||
except requests.RequestException as e:
|
|
||||||
print(f"Error communicating with the server: {e}")
|
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print("Error: Invalid JSON received from the server")
|
print("Error: Invalid JSON received from text-generation-router print-schema")
|
||||||
raise SystemExit(1)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An unexpected error occurred: {e}")
|
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def check_openapi(check: bool):
|
||||||
parser = argparse.ArgumentParser(
|
new_openapi_data = get_openapi_schema()
|
||||||
description="Update documentation for text-generation-launcher"
|
filename = "docs/openapi.json"
|
||||||
)
|
tmp_filename = "openapi_tmp.json"
|
||||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
||||||
|
|
||||||
openapi_parser = subparsers.add_parser(
|
with open(tmp_filename, "w") as f:
|
||||||
"openapi", help="Update OpenAPI documentation"
|
json.dump(new_openapi_data, f, indent=2)
|
||||||
)
|
|
||||||
openapi_parser.add_argument(
|
|
||||||
"--check",
|
|
||||||
action="store_true",
|
|
||||||
help="Check if the OpenAPI documentation needs updating",
|
|
||||||
)
|
|
||||||
|
|
||||||
md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
|
if check:
|
||||||
md_parser.add_argument(
|
diff = subprocess.run(
|
||||||
"--check",
|
["diff", tmp_filename, filename], capture_output=True
|
||||||
action="store_true",
|
).stdout.decode()
|
||||||
help="Check if the launcher documentation needs updating",
|
os.remove(tmp_filename)
|
||||||
)
|
|
||||||
|
if diff:
|
||||||
|
print(diff)
|
||||||
|
raise Exception(
|
||||||
|
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
os.rename(tmp_filename, filename)
|
||||||
|
print("OpenAPI documentation updated.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--check", action="store_true")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.command == "openapi":
|
check_cli(args.check)
|
||||||
openapi(args.check)
|
check_supported_models(args.check)
|
||||||
elif args.command == "md":
|
check_openapi(args.check)
|
||||||
check_cli(args.check)
|
|
||||||
check_supported_models(args.check)
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user