Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-27 13:02:12 +00:00
Max_seq_len (old mpt config.)
Parent: c4ebcea79c
Commit: cd07211411
@@ -22,7 +22,8 @@ mod env_runtime;
 
 #[derive(Deserialize)]
 struct Config {
-    max_position_embeddings: usize,
+    max_position_embeddings: Option<usize>,
+    max_seq_len: Option<usize>,
 }
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
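Note on the Config change above: turning both fields into Option<usize> lets the launcher accept older MPT-style configs that declare max_seq_len instead of max_position_embeddings, since serde's derived Deserialize treats a missing Option field as None. A minimal standalone sketch of that behavior (a trimmed-down stand-in struct, not the launcher's full Config; assumes the serde derive feature and serde_json as dev dependencies):

use serde::Deserialize;

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
}

fn main() {
    // Newer configs expose `max_position_embeddings`.
    let new_style: Config =
        serde_json::from_str(r#"{"max_position_embeddings": 4096}"#).unwrap();
    assert_eq!(new_style.max_position_embeddings, Some(4096));
    assert_eq!(new_style.max_seq_len, None);

    // Old MPT-style configs only expose `max_seq_len`; this now parses too,
    // because a missing Option<_> field deserializes to None.
    let old_style: Config = serde_json::from_str(r#"{"max_seq_len": 2048}"#).unwrap();
    assert_eq!(old_style.max_position_embeddings, None);
    assert_eq!(old_style.max_seq_len, Some(2048));
}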
@@ -817,6 +818,14 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
+impl core::fmt::Display for LauncherError {
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        write!(f, "{self:?}")
+    }
+}
+
+impl std::error::Error for LauncherError {}
+
 fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();
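Note on the LauncherError change above: adding Display plus an empty std::error::Error impl is what allows the enum to be boxed into Box<dyn std::error::Error>, which the validation arm in the next hunk relies on via Err(Box::new(LauncherError::ArgumentValidation(..))). A minimal sketch of the same pattern with a hypothetical stand-in enum (DemoError and pick_max are illustrative, not launcher code):

// Stand-in for LauncherError; only the trait-impl pattern matters here.
#[derive(Debug)]
enum DemoError {
    ArgumentValidation(String),
}

impl core::fmt::Display for DemoError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // Same trick as the diff: reuse the Debug representation.
        write!(f, "{self:?}")
    }
}

impl std::error::Error for DemoError {}

// With Error implemented, the enum can be boxed into a Box<dyn Error>.
fn pick_max(max: Option<usize>) -> Result<usize, Box<dyn std::error::Error>> {
    match max {
        Some(m) => Ok(m),
        None => Err(Box::new(DemoError::ArgumentValidation(
            "no max defined".to_string(),
        ))),
    }
}

fn main() {
    assert_eq!(pick_max(Some(4096)).unwrap(), 4096);
    assert!(pick_max(None).is_err());
}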
@@ -1291,12 +1300,21 @@ fn main() -> Result<(), LauncherError> {
             2usize.pow(14)
         };
 
-        let max_position_embeddings = if config.max_position_embeddings > max_default {
-            let max = config.max_position_embeddings;
-            tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max - 1, max - 1);
-            max_default
-        } else {
-            config.max_position_embeddings
+        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
+            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    max_default
+                } else {
+                    max_position_embeddings
+                }
+            }
+            _ => {
+                return Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )));
+            }
         };
         Ok(max_position_embeddings)
     };
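Note on the match above: it prefers max_position_embeddings, falls back to the old MPT-style max_seq_len, caps the chosen value at max_default (2^14 = 16384), and turns the all-None case into an ArgumentValidation error. The same selection logic written as a hypothetical standalone helper for illustration (resolve_max is not launcher code, and the tracing::info! message is omitted for brevity):

// Prefer max_position_embeddings, fall back to max_seq_len, cap at max_default,
// and error out when neither field is present.
fn resolve_max(
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
    max_default: usize,
) -> Result<usize, String> {
    match (max_position_embeddings, max_seq_len) {
        (Some(max), _) | (None, Some(max)) => Ok(max.min(max_default)),
        _ => Err("no max defined".to_string()),
    }
}

fn main() {
    let max_default = 2usize.pow(14); // 16384, as in the hunk above
    // Newer config: max_position_embeddings wins and is capped at the default.
    assert_eq!(resolve_max(Some(32768), None, max_default), Ok(16384));
    // Old MPT-style config: only max_seq_len is present.
    assert_eq!(resolve_max(None, Some(2048), max_default), Ok(2048));
    // Neither field present: validation error, as in the diff.
    assert!(resolve_max(None, None, max_default).is_err());
}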