Make bf16 default for hpu, fix script (#205)

This commit is contained in:
Abhilash Majumder 2024-08-11 14:18:35 +05:30 committed by GitHub
parent cf2ff5a1dd
commit d403575c43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -182,7 +182,7 @@ struct Args {
speculate: Option<usize>, speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`. /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)] #[clap(default_value = "bfloat16", long, env, value_enum)]
dtype: Option<Dtype>, dtype: Option<Dtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
@ -501,13 +501,12 @@ fn shard_manager(
if let Some(dtype) = dtype { if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string()); shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string()) shard_args.push(dtype.to_string());
} }
// Model optional revision // Model optional revision
if let Some(revision) = revision { if let Some(revision) = revision {
shard_args.push("--revision".to_string()); shard_args.push("--revision".to_string());
shard_args.push(revision) shard_args.push(revision);
} }
let rope = match (rope_scaling, rope_factor) { let rope = match (rope_scaling, rope_factor) {