feat: support google feature and read env vars

This commit is contained in:
drbh 2024-01-17 18:57:43 -05:00
parent f4fd89b224
commit 5e38c4bfda
2 changed files with 10 additions and 35 deletions

View File

@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
[features]
default = ["ngrok"]
ngrok = ["dep:ngrok"]
google = []

View File

@ -21,36 +21,6 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
#[allow(dead_code)] // many of the fields are not used
#[derive(Debug)]
struct VertexAIConfig {
aip_http_port: u16,
aip_predict_route: String,
aip_health_route: String,
}
impl VertexAIConfig {
fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self {
Self {
aip_http_port,
aip_predict_route,
aip_health_route,
}
}
fn to_env(&self) {
// NOTE: this will only set the values for this process
// NOTE: child processes cannot set env vars for their parents
// TODO: find a way to set the values for the whole system
// - maybe write to a file
// - maybe use a shell script to set the values
// - maybe these values are set upstream (before this process is started)
// - if set upstream maybe we read in; if we need them?
std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string());
std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone());
std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone());
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -143,11 +113,6 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
} = args;
// Set Vertex AI config and update the env
let vertex_ai_config =
VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string());
vertex_ai_config.to_env();
// Launch Tokio runtime
init_logging(otlp_endpoint, json_output);
@ -363,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected");
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {