diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 5b5cb45e..a612eb6d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -197,6 +197,10 @@ struct Args { #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + /// The IP address to listen on + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + /// The port to listen on. #[clap(default_value = "3000", long, short, env)] port: u16, @@ -874,6 +878,8 @@ fn spawn_webserver( args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), args.max_waiting_tokens.to_string(), + "--hostname".to_string(), + args.hostname.to_string(), "--port".to_string(), args.port.to_string(), "--master-shard-uds-path".to_string(), diff --git a/router/src/main.rs b/router/src/main.rs index f782be09..4cd89d68 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -40,6 +40,8 @@ struct Args { max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] @@ -82,6 +84,7 @@ fn main() -> Result<(), std::io::Error> { max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, + hostname, port, master_shard_uds_path, tokenizer_name, @@ -213,8 +216,13 @@ fn main() -> Result<(), std::io::Error> { .expect("Unable to warmup model"); tracing::info!("Connected"); - // Binds on localhost - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; // Run server server::run(