Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-14 05:12:08 +00:00)
Upgrade resolution system for fewer errors in resolution.
parent 5eb6ea0063
commit 32f6416358
@@ -33,11 +33,13 @@ impl RadixAllocator
         window_size: Option<u32>,
         prefix_caching: bool,
     ) -> Self {
-        assert_eq!(
-            block_size, 1,
-            "Radix tree allocator only works with block_size=1, was: {}",
-            block_size
-        );
+        if prefix_caching {
+            assert_eq!(
+                block_size, 1,
+                "Radix tree allocator only works with block_size=1, was: {}",
+                block_size
+            );
+        }
         // if window_size.is_some() {
         //     unimplemented!("Window size not supported in the prefix-caching block allocator yet");
         // }
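
The hunk above only enforces the block_size == 1 invariant when prefix caching is enabled, so callers that turn prefix caching off can use larger block sizes without hitting the assertion. A minimal standalone sketch of that guard pattern (hypothetical struct and constructor, not the TGI RadixAllocator itself):

// Sketch of the guard added above: the assertion only fires when prefix
// caching is actually enabled.
struct AllocatorSketch {
    block_size: u32,
    prefix_caching: bool,
}

impl AllocatorSketch {
    fn new(block_size: u32, prefix_caching: bool) -> Self {
        if prefix_caching {
            // Radix (prefix-cache) allocation assumes one token per block.
            assert_eq!(
                block_size, 1,
                "Radix tree allocator only works with block_size=1, was: {}",
                block_size
            );
        }
        Self { block_size, prefix_caching }
    }
}

fn main() {
    // With prefix caching disabled, a block size other than 1 no longer panics.
    let a = AllocatorSketch::new(16, false);
    println!("block_size={} prefix_caching={}", a.block_size, a.prefix_caching);
}
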
flake.lock (12 changed lines)
@@ -835,11 +835,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1724206841,
-        "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=",
+        "lastModified": 1724379657,
+        "narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=",
         "owner": "oxalica",
         "repo": "rust-overlay",
-        "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61",
+        "rev": "a18034322c7703fcfe5d7352a77981ba4a936a61",
         "type": "github"
       },
       "original": {
@@ -944,11 +944,11 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1724218652,
-        "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=",
+        "lastModified": 1724270760,
+        "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=",
         "owner": "danieldk",
         "repo": "tgi-nix",
-        "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808",
+        "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c",
         "type": "github"
       },
       "original": {
|
@@ -56,6 +56,7 @@
       in
       {
         devShells = with pkgs; rec {
+
           default = pure;
 
           pure = mkShell {
@@ -24,36 +24,38 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 
 mod env_runtime;
 
-fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
+fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
-    match config.head_dim {
-        Some(h) if h == 64 || h == 128 || h == 256 => {
-            if lora_adapters.is_some() && prefix_caching.is_none() {
-                tracing::info!("Disabling prefix caching because of lora adapters");
-                prefix_caching = Some("0".to_string());
-            }
-            match config.model_type.as_deref() {
-                Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
-                    // Required because gemma2 needs bfloat16 which is not supported by
-                    // flashinfer ?
-                    if prefix_caching.is_none() {
-                        tracing::info!(
-                            "Forcing flash decoding because model {} requires it",
-                            config.model_type.as_ref().unwrap()
-                        );
-                        prefix_caching = Some("0".to_string());
-                        attention = Some("flashdecoding".to_string());
-                    }
-                }
-                _ => {}
-            }
-        }
-        _ => {
-            if prefix_caching.is_none() {
-                tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                prefix_caching = Some("0".to_string());
-                attention = Some("flashdecoding".to_string());
+    if let Some(config) = config {
+        match config.head_dim {
+            Some(h) if h == 64 || h == 128 || h == 256 => {
+                if lora_adapters.is_some() && prefix_caching.is_none() {
+                    tracing::info!("Disabling prefix caching because of lora adapters");
+                    prefix_caching = Some("0".to_string());
+                }
+                match config.model_type.as_deref() {
+                    Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+                        // Required because gemma2 needs bfloat16 which is not supported by
+                        // flashinfer ?
+                        if prefix_caching.is_none() {
+                            tracing::info!(
+                                "Forcing flash decoding because model {} requires it",
+                                config.model_type.as_ref().unwrap()
+                            );
+                            prefix_caching = Some("0".to_string());
+                            attention = Some("flashdecoding".to_string());
+                        }
+                    }
+                    _ => {}
+                }
+            }
+            _ => {
+                if prefix_caching.is_none() {
+                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                    prefix_caching = Some("0".to_string());
+                    attention = Some("flashdecoding".to_string());
+                }
             }
         }
     }
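
With this change, resolve_attention takes an Option<Config>, so the model-specific overrides only run when a config was actually loaded; without one, the environment-variable values (or the launcher's own defaults) pass through untouched. A simplified, self-contained sketch of that control flow (hypothetical Config type and placeholder default strings, not the launcher's real values):

// Simplified sketch of the Option<Config> guard. The Config struct and the
// fallback strings below are stand-ins; the real launcher derives them from
// the model config and its own defaults.
struct Config {
    head_dim: Option<usize>,
}

fn resolve_attention_sketch(config: &Option<Config>) -> (String, String) {
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        // Model-specific adjustments only happen when a config is available.
        if !matches!(config.head_dim, Some(64) | Some(128) | Some(256)) && prefix_caching.is_none() {
            prefix_caching = Some("0".to_string());
            attention = Some("flashdecoding".to_string());
        }
    }
    // Placeholder fallbacks for the sketch; the real defaults live in the launcher.
    (
        prefix_caching.unwrap_or_else(|| "1".to_string()),
        attention.unwrap_or_else(|| "default".to_string()),
    )
}

fn main() {
    // With no config at all, the function still returns usable values.
    let (prefix_caching, attention) = resolve_attention_sketch(&None);
    println!("prefix_caching={prefix_caching} attention={attention}");
}
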
@@ -1502,68 +1504,68 @@ fn main() -> Result<(), LauncherError> {
 
     tracing::info!("{:#?}", args);
 
-    let get_max_positions_quantize =
-        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
-            let model_id = args.model_id.clone();
-            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-            let filename = if !path.exists() {
-                // Assume it's a hub id
-
-                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                    // env variable has precedence over on file token.
-                    ApiBuilder::new().with_token(Some(token)).build()?
-                } else {
-                    Api::new()?
-                };
-                let repo = if let Some(ref revision) = args.revision {
-                    api.repo(Repo::with_revision(
-                        model_id,
-                        RepoType::Model,
-                        revision.to_string(),
-                    ))
-                } else {
-                    api.model(model_id)
-                };
-                repo.get("config.json")?
-            } else {
-                path.push("config.json");
-                path
-            };
-
-            let content = std::fs::read_to_string(filename)?;
-            let config: RawConfig = serde_json::from_str(&content)?;
-
-            let config: Config = config.into();
-            let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
-            tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-            std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
-            std::env::set_var("ATTENTION", attention);
-            let quantize = config.quantize;
-
-            // Quantization usually means you're even more RAM constrained.
-            let max_default = 4096;
-
-            if let Some(max_position_embeddings) = config.max_position_embeddings {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    Ok((max_default, quantize))
-                } else {
-                    Ok((max_position_embeddings, quantize))
-                }
-            } else {
-                Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )))
-            }
-        };
-    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
-        get_max_positions_quantize().unwrap_or((4096, None));
+    let get_config = || -> Result<Config, Box<dyn std::error::Error>> {
+        let model_id = args.model_id.clone();
+        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+        let filename = if !path.exists() {
+            // Assume it's a hub id
+
+            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                // env variable has precedence over on file token.
+                ApiBuilder::new().with_token(Some(token)).build()?
+            } else {
+                Api::new()?
+            };
+            let repo = if let Some(ref revision) = args.revision {
+                api.repo(Repo::with_revision(
+                    model_id,
+                    RepoType::Model,
+                    revision.to_string(),
+                ))
+            } else {
+                api.model(model_id)
+            };
+            repo.get("config.json")?
+        } else {
+            path.push("config.json");
+            path
+        };
+        let content = std::fs::read_to_string(filename)?;
+        let config: RawConfig = serde_json::from_str(&content)?;
+
+        let config: Config = config.into();
+        Ok(config)
+    };
+    let config: Option<Config> = get_config().ok();
+    let quantize = config.as_ref().and_then(|c| c.quantize);
+    // Quantization usually means you're even more RAM constrained.
+    let max_default = 4096;
+
+    let max_position_embeddings = if let Some(config) = &config {
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                max_default
+            } else {
+                max_position_embeddings
+            }
+        } else {
+            max_default
+        }
+    } else {
+        max_default
+    };
+    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
+    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
+    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("ATTENTION", attention);
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
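
In main(), the old get_max_positions_quantize closure returned an error when no maximum was defined; the new get_config closure only tries to load and parse the model config, and `.ok()` converts any failure into None so that quantize and max_position_embeddings fall back to defaults instead of aborting startup. A condensed sketch of that fallback behaviour (hypothetical Config type and local config.json path, not the launcher's full Hub-aware logic):

// Condensed sketch of the new fallback: any error from the closure becomes
// None, and the Option-based accessors below supply defaults.
#[derive(Debug)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<String>,
}

fn main() {
    let get_config = || -> Result<Config, Box<dyn std::error::Error>> {
        // Any failure here (missing file, bad JSON, ...) becomes Err.
        let _content = std::fs::read_to_string("config.json")?;
        // Parsing elided in this sketch.
        Ok(Config {
            max_position_embeddings: Some(8192),
            quantize: None,
        })
    };

    // `.ok()` swallows the error: startup continues with defaults instead of aborting.
    let config: Option<Config> = get_config().ok();
    let quantize = config.as_ref().and_then(|c| c.quantize.clone());

    let max_default = 4096;
    let max_position_embeddings = config
        .as_ref()
        .and_then(|c| c.max_position_embeddings)
        .map(|max| max.min(max_default))
        .unwrap_or(max_default);

    println!("quantize={quantize:?} max_position_embeddings={max_position_embeddings}");
}
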