diff --git a/router/Cargo.toml b/router/Cargo.toml index 1a7ceb70..7d6dc017 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } [features] default = ["ngrok"] ngrok = ["dep:ngrok"] +google = [] diff --git a/router/src/main.rs b/router/src/main.rs index 65802d5c..60a66a41 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -21,36 +21,6 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; -#[allow(dead_code)] // many of the fields are not used -#[derive(Debug)] -struct VertexAIConfig { - aip_http_port: u16, - aip_predict_route: String, - aip_health_route: String, -} - -impl VertexAIConfig { - fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self { - Self { - aip_http_port, - aip_predict_route, - aip_health_route, - } - } - fn to_env(&self) { - // NOTE: this will only set the values for this process - // NOTE: child processes cannot set env vars for their parents - // TODO: find a way to set the values for the whole system - // - maybe write to a file - // - maybe use a shell script to set the values - // - maybe these values are set upstream (before this process is started) - // - if set upstream maybe we read in; if we need them? - std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string()); - std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone()); - std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone()); - } -} - /// App Configuration #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -143,11 +113,6 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, } = args; - // Set Vertex AI config and update the env - let vertex_ai_config = - VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string()); - vertex_ai_config.to_env(); - // Launch Tokio runtime init_logging(otlp_endpoint, json_output); @@ -363,6 +328,15 @@ async fn main() -> Result<(), RouterError> { tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); + // Determine the server port based on the feature and environment variable. + let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + let addr = match hostname.parse() { Ok(ip) => SocketAddr::new(ip, port), Err(_) => {