Adding new env variables for TPU backends.

This commit is contained in:
Nicolas Patry 2024-04-17 10:00:58 +00:00
parent 06c3d4b1ec
commit 5bc3d65dd3

View File

@ -448,6 +448,8 @@ fn shard_manager(
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
max_total_tokens: usize,
max_batch_size: Option<usize>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -512,6 +514,7 @@ fn shard_manager(
(Some(scaling), Some(factor)) => Some((scaling, factor)),
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
@ -564,6 +567,14 @@ fn shard_manager(
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}
envs.push((
"MAX_TOTAL_TOKENS".into(),
max_total_tokens.to_string().into(),
));
if let Some(max_batch_size) = max_batch_size {
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -967,6 +978,7 @@ fn spawn_shards(
num_shard: usize,
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
@ -998,6 +1010,7 @@ fn spawn_shards(
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
thread::spawn(move || {
shard_manager(
model_id,
@ -1020,6 +1033,8 @@ fn spawn_shards(
cuda_memory_fraction,
rope_scaling,
rope_factor,
max_total_tokens,
max_batch_size,
otlp_endpoint,
status_sender,
shutdown,
@ -1473,6 +1488,7 @@ fn main() -> Result<(), LauncherError> {
num_shard,
&args,
cuda_graphs,
max_total_tokens,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,