diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index 700f763e..24d1d748 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -112,7 +112,7 @@ def serve(
     logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

-    if sharded:
+    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
         tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
         num_shard = int(os.getenv("WORLD_SIZE", "1"))
         logger.info("CLI SHARDED = {}".format(num_shard))
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 321d7c69..d9c41346 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1532,7 +1532,10 @@ fn spawn_shards(
 ) -> Result<(), LauncherError> {
     // Start shard processes
     for rank in 0..num_shard {
-        if rank != 0 && env_runtime::Env::new().is_hpu_device() {
+        if rank != 0
+            && env_runtime::Env::new().is_hpu_device()
+            && std::env::var("ATTENTION").as_deref() != Ok("paged")
+        {
            tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
            break;
        }
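
Note on the two hunks: they implement the same gate on both sides of the launcher/server boundary. When `ATTENTION=paged` on Gaudi/HPU, the launcher spawns one process per rank as it does on other devices, so the server-side fork of `tgi_service.py` must be skipped; for any other attention backend the server keeps doing its own sharding as before. A minimal Python sketch of the shared predicate (the helper name `launcher_handles_sharding` is hypothetical, not part of the TGI codebase):

```python
import os


def launcher_handles_sharding() -> bool:
    """Hypothetical helper mirroring the condition added in both hunks.

    Returns True when ATTENTION=paged, meaning the launcher spawns one
    shard process per rank and the Gaudi server must NOT fork
    tgi_service.py itself.
    """
    return os.getenv("ATTENTION", "default") in {"paged"}


# Server side (cli.py): fork tgi_service.py only when the launcher
# is not already handling per-rank processes.
sharded = int(os.getenv("WORLD_SIZE", "1")) > 1
if sharded and not launcher_handles_sharding():
    print("server forks tgi_service.py across WORLD_SIZE ranks")
else:
    print("launcher owns sharding; server runs a single rank")
```

Keeping the env-var check identical in `cli.py` and `launcher/src/main.rs` is what prevents the two components from double-spawning (or never spawning) shard processes.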