This commit is contained in:
OlivierDehaene 2023-04-26 20:03:46 +02:00
parent 6d8d5b6d1d
commit 018e87d78d
2 changed files with 12 additions and 11 deletions

View File

@ -493,6 +493,7 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn spawn_shards(
num_shard: usize,
args: &Args,
@ -515,11 +516,11 @@ fn spawn_shards(
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize.clone();
let master_port = args.master_port.clone();
let disable_custom_kernels = args.disable_custom_kernels.clone();
let watermark_gamma = args.watermark_gamma.clone();
let watermark_delta = args.watermark_delta.clone();
let quantize = args.quantize;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
thread::spawn(move || {
shard_manager(
model_id,
@ -559,12 +560,12 @@ fn spawn_shards(
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
shutdown_shards(shutdown, &shutdown_receiver);
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardCannotStart);
}
Err(TryRecvError::Disconnected) => {
tracing::error!("Shard status channel disconnected");
shutdown_shards(shutdown, &shutdown_receiver);
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardDisconnected);
}
}
@ -666,7 +667,7 @@ fn spawn_webserver(
tracing::error!("{}", err);
}
shutdown_shards(shutdown, &shutdown_receiver);
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::WebserverCannotStart);
}
};

View File

@ -551,8 +551,8 @@ pub async fn run(
max_input_length,
max_total_tokens,
);
let healthy = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), healthy.clone());
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
let infer = Infer::new(
client,
validation,
@ -561,7 +561,7 @@ pub async fn run(
max_waiting_tokens,
max_concurrent_requests,
shard_info.requires_padding,
healthy,
generation_health,
);
// Duration buckets