From 432961324e7f8d5feedf666c9753ad1fd061399f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 26 Apr 2024 15:44:44 +0200
Subject: [PATCH] Adding new env variables for TPU backends. (#1755)

# What does this PR do?

On TPU (and probably Inferentia), the model needs to know right off the bat
about MAX_BATCH_SIZE and MAX_TOTAL_TOKENS, since the entire cache is
determined by both. This PR sends that information to the shards so they can
allocate accordingly. It should be a no-op for other backends.

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other
      checks if that's the case).
- [ ] Did you read the
      [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
      [forum](https://discuss.huggingface.co/)? Please add a link to it if
      that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here
      are the
      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs),
      and
      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Feel free to tag members/contributors who may be interested in your PR.
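To make the contract concrete, here is a minimal shard-side sketch of how a backend could consume the two variables. This is illustrative only: the shard server is not part of this diff, and the parsing and fallback behavior shown are assumptions, not TGI's implementation.

```rust
use std::env;

fn main() {
    // MAX_TOTAL_TOKENS is always exported by the launcher (see diff below).
    let max_total_tokens: usize = env::var("MAX_TOTAL_TOKENS")
        .expect("launcher always sets MAX_TOTAL_TOKENS")
        .parse()
        .expect("MAX_TOTAL_TOKENS must be an integer");

    // MAX_BATCH_SIZE is only exported when the user supplied one.
    let max_batch_size: Option<usize> = env::var("MAX_BATCH_SIZE")
        .ok()
        .map(|v| v.parse().expect("MAX_BATCH_SIZE must be an integer"));

    match max_batch_size {
        Some(batch_size) => {
            // Both bounds are known before the first request arrives, so the
            // whole cache can be allocated up front, as a TPU backend requires.
            println!("pre-allocating cache: {batch_size} sequences x {max_total_tokens} tokens");
        }
        // Hypothetical fallback: a backend that never reads MAX_BATCH_SIZE
        // behaves exactly as before, which is the no-op case.
        None => println!("MAX_BATCH_SIZE unset; keeping default behavior"),
    }
}
```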
---
 launcher/src/main.rs | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f4f2c533..01c60bb4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -435,7 +435,6 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
-    max_total_tokens: usize,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -451,6 +450,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -515,6 +516,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -527,9 +529,6 @@ fn shard_manager(
     // Remove LOG_LEVEL if present
     envs.retain(|(name, _)| name != "LOG_LEVEL");
 
-    // Max total tokens
-    envs.push(("MAX_TOTAL_TOKENS".into(), max_total_tokens.to_string().into()));
-
     // Torch Distributed Env vars
     if world_size == 1 {
         envs.push(("RANK".into(), rank.to_string().into()));
@@ -572,6 +571,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }
 
+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -975,13 +982,13 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
     status_receiver: &mpsc::Receiver<ShardStatus>,
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
-    max_total_tokens: usize,
 ) -> Result<(), LauncherError> {
     // Start shard processes
     for rank in 0..num_shard {
@@ -1007,6 +1014,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1014,7 +1022,6 @@ fn spawn_shards(
                 quantize,
                 speculate,
                 dtype,
-                max_total_tokens,
                 trust_remote_code,
                 uds_path,
                 rank,
@@ -1030,6 +1037,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -1494,13 +1503,13 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
         &status_receiver,
         status_sender,
         running.clone(),
-        max_total_tokens,
     )?;
 
     // We might have received a termination signal
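For reference, the forwarding logic above boils down to the following standalone sketch. Here `shard_env` is a hypothetical helper, not a function that exists in launcher/src/main.rs: MAX_TOTAL_TOKENS is pushed unconditionally, while MAX_BATCH_SIZE is pushed only when the user set one, which is what keeps the change a no-op for other backends.

```rust
/// Hypothetical helper distilling the env plumbing this patch adds to
/// shard_manager; not a function that exists in launcher/src/main.rs.
fn shard_env(max_total_tokens: usize, max_batch_size: Option<usize>) -> Vec<(String, String)> {
    let mut envs = Vec::new();
    // Always forwarded: every shard learns its token budget up front.
    envs.push(("MAX_TOTAL_TOKENS".to_string(), max_total_tokens.to_string()));
    // Forwarded only when set, so backends that never read it see no change.
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".to_string(), max_batch_size.to_string()));
    }
    envs
}

fn main() {
    // e.g. a launcher invoked with `--max-total-tokens 4096 --max-batch-size 8`
    for (name, value) in shard_env(4096, Some(8)) {
        println!("{name}={value}");
    }
}
```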