diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index f32b1ad4..67889e85 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -96,7 +96,7 @@ struct Args { /// Generation parameter in case you want to specifically test/debug particular /// decoding strategies, for full doc refer to the `text-generation-server` - #[clap(default_values_t="vec![]", long, env)] + #[clap(long, env, value_parser=parse_key_val::)] logit_bias: Vec<(String, f32)>, } @@ -125,6 +125,8 @@ fn main() -> Result<(), Box> { logit_bias, } = args; + dbg!(&logit_bias); + let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); // Tokenizer instance @@ -213,3 +215,18 @@ fn init_logging() { .with(fmt_layer) .init(); } + +// Taken from https://github.com/clap-rs/clap/blob/master/examples/typed-derive.rs#L48 +// Used to parse LogitBias's +pub(crate) fn parse_key_val(s: &str) -> Result<(T, U), Box> + where + T: std::str::FromStr, + T::Err: std::error::Error + Send + Sync + 'static, + U: std::str::FromStr, + U::Err: std::error::Error + Send + Sync + 'static, +{ + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) +} \ No newline at end of file