misc(lint): make clippy happier

This commit is contained in:
Morgan Funtowicz 2024-11-03 14:26:57 +01:00
parent 31d9254776
commit 188442f67d
3 changed files with 21 additions and 53 deletions

36
Cargo.lock generated
View File

@ -4239,7 +4239,7 @@ dependencies = [
"tracing", "tracing",
"tracing-opentelemetry 0.27.0", "tracing-opentelemetry 0.27.0",
"tracing-subscriber", "tracing-subscriber",
"utoipa 5.1.2", "utoipa",
] ]
[[package]] [[package]]
@ -4368,7 +4368,7 @@ dependencies = [
"tracing-opentelemetry 0.21.0", "tracing-opentelemetry 0.21.0",
"tracing-subscriber", "tracing-subscriber",
"ureq", "ureq",
"utoipa 4.2.3", "utoipa",
"utoipa-swagger-ui", "utoipa-swagger-ui",
"uuid", "uuid",
"vergen", "vergen",
@ -4419,7 +4419,7 @@ dependencies = [
"tracing", "tracing",
"tracing-opentelemetry 0.21.0", "tracing-opentelemetry 0.21.0",
"tracing-subscriber", "tracing-subscriber",
"utoipa 4.2.3", "utoipa",
"utoipa-swagger-ui", "utoipa-swagger-ui",
] ]
@ -4470,7 +4470,7 @@ dependencies = [
"tracing", "tracing",
"tracing-opentelemetry 0.21.0", "tracing-opentelemetry 0.21.0",
"tracing-subscriber", "tracing-subscriber",
"utoipa 4.2.3", "utoipa",
"utoipa-swagger-ui", "utoipa-swagger-ui",
] ]
@ -5192,19 +5192,7 @@ dependencies = [
"indexmap 2.6.0", "indexmap 2.6.0",
"serde", "serde",
"serde_json", "serde_json",
"utoipa-gen 4.3.0", "utoipa-gen",
]
[[package]]
name = "utoipa"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e12e84f0ff45b6818029cd0f67280e453c80132c1b9897df407ecc20b9f7cfd"
dependencies = [
"indexmap 2.5.0",
"serde",
"serde_json",
"utoipa-gen 5.1.2",
] ]
[[package]] [[package]]
@ -5220,18 +5208,6 @@ dependencies = [
"syn 2.0.85", "syn 2.0.85",
] ]
[[package]]
name = "utoipa-gen"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dfc694d3a3118d2b9e80d68be83bf1aab7988510916934db83da61c14e7e6b2"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.79",
]
[[package]] [[package]]
name = "utoipa-swagger-ui" name = "utoipa-swagger-ui"
version = "6.0.0" version = "6.0.0"
@ -5244,7 +5220,7 @@ dependencies = [
"rust-embed", "rust-embed",
"serde", "serde",
"serde_json", "serde_json",
"utoipa 4.2.3", "utoipa",
"zip", "zip",
] ]

View File

@ -22,7 +22,7 @@ tokenizers = { workspace = true }
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.27.0" tracing-opentelemetry = "0.27.0"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
utoipa = { version = "5.1.2", features = ["axum_extras"] } utoipa = { version = "4.2.3", features = ["axum_extras"] }
log = "0.4.22" log = "0.4.22"
[build-dependencies] [build-dependencies]

View File

@ -1,6 +1,7 @@
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use std::path::PathBuf; use std::path::PathBuf;
use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError}; use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
use text_generation_router::server::ApiDoc;
use text_generation_router::{server, usage_stats}; use text_generation_router::{server, usage_stats};
use thiserror::Error; use thiserror::Error;
@ -35,13 +36,8 @@ struct Args {
port: u16, port: u16,
#[clap(long, env, help = "Path to GGUF model file(s) to load")] #[clap(long, env, help = "Path to GGUF model file(s) to load")]
gguf_path: PathBuf, gguf_path: PathBuf,
#[clap( #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
long, num_model_instance: u16,
env,
default_value = "1",
help = "Number of CPU threads allocated to one llama.cpp model"
)]
cores_per_instance: u16,
#[clap(default_value = "bigscience/bloom", long, env)] #[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String, tokenizer_name: String,
#[clap(long, env)] #[clap(long, env)]
@ -67,8 +63,6 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
@ -100,7 +94,7 @@ async fn main() -> Result<(), RouterError> {
hostname, hostname,
port, port,
gguf_path, gguf_path,
cores_per_instance, num_model_instance,
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
@ -113,19 +107,17 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,
} = args; } = args;
// if let Some(Commands::PrintSchema) = command { if let Some(Commands::PrintSchema) = command {
// use utoipa::OpenApi; use utoipa::OpenApi;
// let api_doc = ApiDoc::openapi(); let api_doc = ApiDoc::openapi().to_pretty_json().unwrap();
// let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); println!("{}", api_doc);
// println!("{}", api_doc); std::process::exit(0);
// std::process::exit(0); };
// };
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args // Validate args
@ -144,11 +136,11 @@ async fn main() -> Result<(), RouterError> {
)); ));
} }
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { if let Some(max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens { if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
} }
if max_total_tokens as u32 > *max_batch_total_tokens { if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
} }
} }
@ -177,13 +169,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
false,
hostname, hostname,
port, port,
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,