Final touches.

This commit is contained in:
Nicolas Patry 2023-06-30 08:39:15 +00:00
parent 59474c29aa
commit 02e43ccf6f
2 changed files with 11 additions and 1 deletion

View File

@ -289,6 +289,7 @@ fn shard_manager(
model_id: String, model_id: String,
revision: Option<String>, revision: Option<String>,
quantize: Option<Quantization>, quantize: Option<Quantization>,
dtype: Option<Dtype>,
trust_remote_code: bool, trust_remote_code: bool,
uds_path: String, uds_path: String,
rank: usize, rank: usize,
@ -338,6 +339,11 @@ fn shard_manager(
shard_argv.push(quantize.to_string()) shard_argv.push(quantize.to_string())
} }
if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string());
shard_argv.push(dtype.to_string())
}
// Model optional revision // Model optional revision
if let Some(revision) = revision { if let Some(revision) = revision {
shard_argv.push("--revision".to_string()); shard_argv.push("--revision".to_string());
@ -768,6 +774,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone(); let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize; let quantize = args.quantize;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code; let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port; let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels; let disable_custom_kernels = args.disable_custom_kernels;
@ -778,6 +785,7 @@ fn spawn_shards(
model_id, model_id,
revision, revision,
quantize, quantize,
dtype,
trust_remote_code, trust_remote_code,
uds_path, uds_path,
rank, rank,

View File

@ -156,4 +156,6 @@ def serve(
logger.info("Signal received. Shutting down") logger.info("Signal received. Shutting down")
await server.stop(0) await server.stop(0)
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code)) asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
)