From 02e43ccf6f0e9d3adf5e2fd9685478ff0bb662bb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 30 Jun 2023 08:39:15 +0000 Subject: [PATCH] Final touches. --- launcher/src/main.rs | 8 ++++++++ server/text_generation_server/server.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 1e082ca1..9d1932cc 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -289,6 +289,7 @@ fn shard_manager( model_id: String, revision: Option, quantize: Option, + dtype: Option, trust_remote_code: bool, uds_path: String, rank: usize, @@ -338,6 +339,11 @@ fn shard_manager( shard_argv.push(quantize.to_string()) } + if let Some(dtype) = dtype { + shard_argv.push("--dtype".to_string()); + shard_argv.push(dtype.to_string()) + } + // Model optional revision if let Some(revision) = revision { shard_argv.push("--revision".to_string()); @@ -768,6 +774,7 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let quantize = args.quantize; + let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; let disable_custom_kernels = args.disable_custom_kernels; @@ -778,6 +785,7 @@ fn spawn_shards( model_id, revision, quantize, + dtype, trust_remote_code, uds_path, rank, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 5d2702d0..e59cc108 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -156,4 +156,6 @@ def serve( logger.info("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code)) + asyncio.run( + serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code) + )