feat: add pre commit step to force schema update when router changes

This commit is contained in:
drbh 2024-07-02 14:47:01 +00:00
parent 0759ec495e
commit 7b34ba3408
4 changed files with 137 additions and 89 deletions

View File

@ -16,3 +16,40 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
- repo: local
hooks:
- id: check-openapi-update
name: check openapi spec update
entry: python
language: system
pass_filenames: false
always_run: true
args:
- -c
- |
import os
import sys
import subprocess
def get_changed_files():
result = subprocess.run(['git', 'diff', '--name-only'], capture_output=True, text=True)
return result.stdout.splitlines()
changed_files = get_changed_files()
router_files = [f for f in changed_files if f.startswith('router/')]
if not router_files:
print("No router files changed. Skipping OpenAPI spec check.")
sys.exit(0)
openapi_file = 'docs/openapi.json'
if not os.path.exists(openapi_file):
print(f"Error: {openapi_file} does not exist.")
sys.exit(1)
if openapi_file not in changed_files:
print(f"Error: Router files were updated, but {openapi_file} was not updated.")
sys.exit(1)
else:
print(f"{openapi_file} has been updated along with router changes.")
sys.exit(0)

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.0.1"
"version": "2.1.1-dev0"
},
"paths": {
"/": {
@ -19,7 +19,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"operationId": "compat_generate",
"requestBody": {
"content": {
@ -108,7 +107,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate",
"requestBody": {
"content": {
@ -192,7 +190,6 @@
"Text Generation Inference"
],
"summary": "Generate a stream of token using Server-Sent Events",
"description": "Generate a stream of token using Server-Sent Events",
"operationId": "generate_stream",
"requestBody": {
"content": {
@ -276,7 +273,6 @@
"Text Generation Inference"
],
"summary": "Health check method",
"description": "Health check method",
"operationId": "health",
"responses": {
"200": {
@ -305,7 +301,6 @@
"Text Generation Inference"
],
"summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info",
"responses": {
"200": {
@ -327,7 +322,6 @@
"Text Generation Inference"
],
"summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics",
"responses": {
"200": {
@ -349,7 +343,6 @@
"Text Generation Inference"
],
"summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize",
"requestBody": {
"content": {
@ -394,7 +387,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions",
"requestBody": {
"content": {
@ -483,7 +475,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "completions",
"requestBody": {
"content": {
@ -626,7 +617,6 @@
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
@ -653,9 +643,6 @@
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
},
@ -697,7 +684,6 @@
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
@ -723,9 +709,6 @@
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
@ -756,34 +739,19 @@
"nullable": true
},
"message": {
"$ref": "#/components/schemas/Message"
"$ref": "#/components/schemas/OutputMessage"
}
}
},
"ChatCompletionDelta": {
"type": "object",
"required": [
"role"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?",
"nullable": true
"oneOf": [
{
"$ref": "#/components/schemas/TextMessage"
},
"role": {
"type": "string",
"example": "user"
},
"tool_calls": {
"allOf": [
{
"$ref": "#/components/schemas/DeltaToolCall"
}
],
"nullable": true
{
"$ref": "#/components/schemas/ToolCallDelta"
}
}
]
},
"ChatCompletionLogprob": {
"type": "object",
@ -903,6 +871,15 @@
"example": 0.1,
"nullable": true
},
"response_format": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"nullable": true
},
"seed": {
"type": "integer",
"format": "int64",
@ -1021,7 +998,6 @@
"type": "object",
"required": [
"id",
"object",
"created",
"choices",
"model",
@ -1045,9 +1021,6 @@
"model": {
"type": "string"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
@ -1081,12 +1054,7 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"prompt": {
"type": "array",
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.",
"example": "What is Deep Learning?"
"$ref": "#/components/schemas/Prompt"
},
"repetition_penalty": {
"type": "number",
@ -1100,6 +1068,15 @@
"nullable": true,
"minimum": 0
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
},
"stream": {
"type": "boolean"
},
@ -1121,15 +1098,6 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
}
}
},
@ -1272,8 +1240,16 @@
"GenerateParameters": {
"type": "object",
"properties": {
"adapter_id": {
"type": "string",
"description": "Lora adapter id",
"default": "null",
"example": "null",
"nullable": true
},
"best_of": {
"type": "integer",
"description": "Generate best_of sequences and return the one if the highest token logprobs.",
"default": "null",
"example": 1,
"nullable": true,
@ -1282,20 +1258,24 @@
},
"decoder_input_details": {
"type": "boolean",
"description": "Whether to return decoder input token logprobs and ids.",
"default": "false"
},
"details": {
"type": "boolean",
"description": "Whether to return generation details.",
"default": "true"
},
"do_sample": {
"type": "boolean",
"description": "Activate logits sampling.",
"default": "false",
"example": true
},
"frequency_penalty": {
"type": "number",
"format": "float",
"description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"default": "null",
"example": 0.1,
"nullable": true,
@ -1313,6 +1293,7 @@
"max_new_tokens": {
"type": "integer",
"format": "int32",
"description": "Maximum number of tokens to generate.",
"default": "100",
"example": "20",
"nullable": true,
@ -1321,6 +1302,7 @@
"repetition_penalty": {
"type": "number",
"format": "float",
"description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
"default": "null",
"example": 1.03,
"nullable": true,
@ -1328,6 +1310,7 @@
},
"return_full_text": {
"type": "boolean",
"description": "Whether to prepend the prompt to the generated text",
"default": "null",
"example": false,
"nullable": true
@ -1335,6 +1318,7 @@
"seed": {
"type": "integer",
"format": "int64",
"description": "Random sampling seed.",
"default": "null",
"example": "null",
"nullable": true,
@ -1346,6 +1330,7 @@
"items": {
"type": "string"
},
"description": "Stop generating tokens if a member of `stop` is generated.",
"example": [
"photographer"
],
@ -1354,6 +1339,7 @@
"temperature": {
"type": "number",
"format": "float",
"description": "The value used to module the logits distribution.",
"default": "null",
"example": 0.5,
"nullable": true,
@ -1362,6 +1348,7 @@
"top_k": {
"type": "integer",
"format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"default": "null",
"example": 10,
"nullable": true,
@ -1370,6 +1357,7 @@
"top_n_tokens": {
"type": "integer",
"format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
"default": "null",
"example": 5,
"nullable": true,
@ -1379,6 +1367,7 @@
"top_p": {
"type": "number",
"format": "float",
"description": "Top-p value for nucleus sampling.",
"default": "null",
"example": 0.95,
"nullable": true,
@ -1387,6 +1376,7 @@
},
"truncate": {
"type": "integer",
"description": "Truncate inputs tokens to the given size.",
"default": "null",
"example": "null",
"nullable": true,
@ -1395,6 +1385,7 @@
"typical_p": {
"type": "number",
"format": "float",
"description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
"default": "null",
"example": 0.95,
"nullable": true,
@ -1403,6 +1394,7 @@
},
"watermark": {
"type": "boolean",
"description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
"default": "false",
"example": true
}
@ -1495,13 +1487,14 @@
"max_concurrent_requests",
"max_best_of",
"max_stop_sequences",
"max_input_length",
"max_input_tokens",
"max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"max_client_batch_size",
"router",
"version"
],
"properties": {
@ -1538,7 +1531,7 @@
"example": "128",
"minimum": 0
},
"max_input_length": {
"max_input_tokens": {
"type": "integer",
"example": "1024",
"minimum": 0
@ -1581,6 +1574,11 @@
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true
},
"router": {
"type": "string",
"description": "Router Info",
"example": "text-generation-router"
},
"sha": {
"type": "string",
"example": "null",
@ -1593,7 +1591,6 @@
},
"version": {
"type": "string",
"description": "Router Info",
"example": "0.5.0"
},
"waiting_served_ratio": {
@ -1606,13 +1603,12 @@
"Message": {
"type": "object",
"required": [
"role"
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "My name is David and I",
"nullable": true
"$ref": "#/components/schemas/MessageContent"
},
"name": {
"type": "string",
@ -1622,13 +1618,6 @@
"role": {
"type": "string",
"example": "user"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
},
"nullable": true
}
}
},
@ -1817,9 +1806,7 @@
"$ref": "#/components/schemas/FunctionDefinition"
},
"id": {
"type": "integer",
"format": "int32",
"minimum": 0
"type": "string"
},
"type": {
"type": "string"
@ -1830,20 +1817,22 @@
"oneOf": [
{
"type": "object",
"required": [
"FunctionName"
],
"properties": {
"FunctionName": {
"type": "string"
}
}
"default": null,
"nullable": true
},
{
"type": "string",
"enum": [
"OneOf"
]
"type": "string"
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
]
},

View File

@ -83,6 +83,8 @@ struct Args {
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t = false)]
update_openapi_schema: bool,
}
#[tokio::main]
@ -119,6 +121,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
update_openapi_schema,
} = args;
// Launch Tokio runtime
@ -388,6 +391,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
update_openapi_schema,
)
.await?;
Ok(())

View File

@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
update_openapi_schema: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
@ -1499,7 +1500,24 @@ pub async fn run(
)]
struct ApiDoc;
// Create state
if update_openapi_schema {
use std::io::Write;
let cargo_workspace =
std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
println!("workspace {}", cargo_workspace);
let output_file = format!("{}/../docs/openapi.json", cargo_workspace);
println!("output file {}", output_file);
let openapi = ApiDoc::openapi();
let mut file = std::fs::File::create(output_file).expect("Unable to create file");
file.write_all(
openapi
.to_pretty_json()
.expect("Unable to serialize OpenAPI")
.as_bytes(),
)
.expect("Unable to write data");
}
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (