fixed clap arg parsing.

This commit is contained in:
marcusdunn 2023-08-15 14:03:27 -07:00
parent 9b3be8f79b
commit ed2efe3dd9

View File

@ -96,7 +96,7 @@ struct Args {
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(default_values_t="vec![]", long, env)]
#[clap(long, env, value_parser=parse_key_val::<String, f32>)]
logit_bias: Vec<(String, f32)>,
}
@ -125,6 +125,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
logit_bias,
} = args;
dbg!(&logit_bias);
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
// Tokenizer instance
@ -213,3 +215,18 @@ fn init_logging() {
.with(fmt_layer)
.init();
}
// Taken from https://github.com/clap-rs/clap/blob/master/examples/typed-derive.rs#L48
// Used to parse LogitBias's
pub(crate) fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}