mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fixed clap arg parsing.
This commit is contained in:
parent
9b3be8f79b
commit
ed2efe3dd9
@ -96,7 +96,7 @@ struct Args {
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(default_values_t="vec![]", long, env)]
|
||||
#[clap(long, env, value_parser=parse_key_val::<String, f32>)]
|
||||
logit_bias: Vec<(String, f32)>,
|
||||
}
|
||||
|
||||
@ -125,6 +125,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
logit_bias,
|
||||
} = args;
|
||||
|
||||
dbg!(&logit_bias);
|
||||
|
||||
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||
|
||||
// Tokenizer instance
|
||||
@ -213,3 +215,18 @@ fn init_logging() {
|
||||
.with(fmt_layer)
|
||||
.init();
|
||||
}
|
||||
|
||||
// Taken from https://github.com/clap-rs/clap/blob/master/examples/typed-derive.rs#L48
|
||||
// Used to parse LogitBias's
|
||||
pub(crate) fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr,
|
||||
T::Err: std::error::Error + Send + Sync + 'static,
|
||||
U: std::str::FromStr,
|
||||
U::Err: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let pos = s
|
||||
.find('=')
|
||||
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
|
||||
}
|
Loading…
Reference in New Issue
Block a user