Adding new env variables for TPU backends.

commit 5bc3d65dd3 (parent 06c3d4b1ec)
Author: Nicolas Patry
Date: 2024-04-17 10:00:58 +00:00


@@ -448,6 +448,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -512,6 +514,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -564,6 +567,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }
 
+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
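
The two pushes above are the heart of the change: MAX_TOTAL_TOKENS is now exported to every shard unconditionally, while MAX_BATCH_SIZE is exported only when --max-batch-size was passed on the command line. On the receiving side, a TPU backend can read the values back out of its environment; here is a minimal sketch of such a consumer (hypothetical code, not part of this commit; only the two variable names come from the diff above):

    use std::env;

    /// Hypothetical TPU-backend startup code reading the variables
    /// exported by the launcher. Only the variable names are taken
    /// from the commit; everything else is illustrative.
    fn read_tpu_limits() -> (usize, Option<usize>) {
        // MAX_TOTAL_TOKENS is always exported after this change.
        let max_total_tokens = env::var("MAX_TOTAL_TOKENS")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .expect("MAX_TOTAL_TOKENS must be set to a number");

        // MAX_BATCH_SIZE is exported only when --max-batch-size was
        // given, so it stays optional on this side as well.
        let max_batch_size = env::var("MAX_BATCH_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok());

        (max_total_tokens, max_batch_size)
    }

    fn main() {
        let (max_total_tokens, max_batch_size) = read_tpu_limits();
        println!("max_total_tokens={max_total_tokens} max_batch_size={max_batch_size:?}");
    }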
@@ -967,6 +978,7 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -998,6 +1010,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1020,6 +1033,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
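
The two hunks above also show why max_batch_size is copied into a local before the thread is spawned: spawn_shards only borrows args (args: &Args), and the move closure handed to thread::spawn must own everything it captures for the thread's 'static lifetime. A reduced sketch of that pattern (hypothetical names, not launcher code):

    use std::thread;

    struct Args {
        max_batch_size: Option<usize>,
    }

    fn spawn_worker(args: &Args, max_total_tokens: usize) -> thread::JoinHandle<()> {
        // Copy the field out of the borrowed struct first: a `move`
        // closure passed to thread::spawn must be 'static and cannot
        // capture `&Args` directly.
        let max_batch_size = args.max_batch_size;
        thread::spawn(move || {
            println!("worker: max_total_tokens={max_total_tokens} max_batch_size={max_batch_size:?}");
        })
    }

    fn main() {
        let args = Args { max_batch_size: Some(8) };
        spawn_worker(&args, 4096).join().unwrap();
    }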
@@ -1473,6 +1488,7 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
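
For reference, the envs.push(("MAX_TOTAL_TOKENS".into(), ...)) lines work because the launcher accumulates the child's environment as (OsString, OsString) pairs and hands the whole vector to the spawned shard process. A self-contained sketch of that pattern (simplified; the child command here is a stand-in, not what the launcher actually spawns):

    use std::ffi::OsString;
    use std::process::Command;

    fn main() {
        // Stand-in values; in the launcher these come from CLI args.
        let max_total_tokens: usize = 4096;
        let max_batch_size: Option<usize> = Some(8);

        // Collect env vars as (name, value) pairs, mirroring the diff above.
        let mut envs: Vec<(OsString, OsString)> = Vec::new();
        envs.push((
            "MAX_TOTAL_TOKENS".into(),
            max_total_tokens.to_string().into(),
        ));
        if let Some(max_batch_size) = max_batch_size {
            envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
        }

        // Command::envs accepts any iterator of key/value pairs.
        let status = Command::new("env") // stand-in child: prints its environment
            .envs(envs)
            .status()
            .expect("failed to spawn child");
        println!("child exited with {status}");
    }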