Upgrade the resolution system to reduce resolution errors.

Nicolas Patry 2024-08-23 15:27:53 +02:00
parent 5eb6ea0063
commit 32f6416358
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
4 changed files with 99 additions and 94 deletions


@@ -33,11 +33,13 @@ impl RadixAllocator {
         window_size: Option<u32>,
         prefix_caching: bool,
     ) -> Self {
-        assert_eq!(
-            block_size, 1,
-            "Radix tree allocator only works with block_size=1, was: {}",
-            block_size
-        );
+        if prefix_caching {
+            assert_eq!(
+                block_size, 1,
+                "Radix tree allocator only works with block_size=1, was: {}",
+                block_size
+            );
+        }
         // if window_size.is_some() {
         //     unimplemented!("Window size not supported in the prefix-caching block allocator yet");
         // }
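In short: the block_size == 1 invariant is now only enforced when prefix caching is enabled. A minimal, self-contained sketch of that guard (hypothetical Allocator type, not the real RadixAllocator):

// Sketch only: mirrors the guarded assertion above with illustrative names.
#[derive(Debug)]
struct Allocator {
    block_size: u32,
    prefix_caching: bool,
}

impl Allocator {
    fn new(block_size: u32, prefix_caching: bool) -> Self {
        // The radix-tree constraint only matters when prefix caching is on.
        if prefix_caching {
            assert_eq!(
                block_size, 1,
                "Radix tree allocator only works with block_size=1, was: {}",
                block_size
            );
        }
        Self { block_size, prefix_caching }
    }
}

fn main() {
    let plain = Allocator::new(16, false); // fine: assertion skipped entirely
    let radix = Allocator::new(1, true);   // fine: invariant holds
    // Allocator::new(16, true) would panic with the message above.
    println!("{plain:?} {radix:?}");
}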


@@ -835,11 +835,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1724206841,
-      "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=",
+      "lastModified": 1724379657,
+      "narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61",
+      "rev": "a18034322c7703fcfe5d7352a77981ba4a936a61",
       "type": "github"
     },
     "original": {
@@ -944,11 +944,11 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1724218652,
-      "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=",
+      "lastModified": 1724270760,
+      "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=",
       "owner": "danieldk",
       "repo": "tgi-nix",
-      "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808",
+      "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c",
       "type": "github"
     },
     "original": {


@@ -56,6 +56,7 @@
       in
       {
         devShells = with pkgs; rec {
+          default = pure;
           pure = mkShell {


@@ -24,36 +24,38 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 mod env_runtime;
 
-fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
+fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
-    match config.head_dim {
-        Some(h) if h == 64 || h == 128 || h == 256 => {
-            if lora_adapters.is_some() && prefix_caching.is_none() {
-                tracing::info!("Disabling prefix caching because of lora adapters");
-                prefix_caching = Some("0".to_string());
-            }
-            match config.model_type.as_deref() {
-                Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
-                    // Required because gemma2 needs bfloat16 which is not supported by
-                    // flashinfer ?
-                    if prefix_caching.is_none() {
-                        tracing::info!(
-                            "Forcing flash decoding because model {} requires it",
-                            config.model_type.as_ref().unwrap()
-                        );
-                        prefix_caching = Some("0".to_string());
-                        attention = Some("flashdecoding".to_string());
-                    }
-                }
-                _ => {}
-            }
-        }
-        _ => {
-            if prefix_caching.is_none() {
-                tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                prefix_caching = Some("0".to_string());
-                attention = Some("flashdecoding".to_string());
-            }
-        }
+    if let Some(config) = config {
+        match config.head_dim {
+            Some(h) if h == 64 || h == 128 || h == 256 => {
+                if lora_adapters.is_some() && prefix_caching.is_none() {
+                    tracing::info!("Disabling prefix caching because of lora adapters");
+                    prefix_caching = Some("0".to_string());
+                }
+                match config.model_type.as_deref() {
+                    Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+                        // Required because gemma2 needs bfloat16 which is not supported by
+                        // flashinfer ?
+                        if prefix_caching.is_none() {
+                            tracing::info!(
+                                "Forcing flash decoding because model {} requires it",
+                                config.model_type.as_ref().unwrap()
+                            );
+                            prefix_caching = Some("0".to_string());
+                            attention = Some("flashdecoding".to_string());
+                        }
+                    }
+                    _ => {}
+                }
+            }
+            _ => {
+                if prefix_caching.is_none() {
+                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                    prefix_caching = Some("0".to_string());
+                    attention = Some("flashdecoding".to_string());
+                }
+            }
+        }
     }
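Net effect of the signature change: a model whose config could not be loaded no longer breaks attention resolution, because the head-dim and model-type checks only run when a config is present. A condensed sketch of that fallback shape (simplified Config, a single check, and assumed defaults; not the launcher's actual code):

// Simplified sketch; types, check, and defaults are assumptions for illustration.
struct Config {
    head_dim: Option<usize>,
}

fn resolve_attention(config: &Option<Config>) -> (String, String) {
    let mut prefix_caching = std::env::var("USE_PREFIX_CACHING").ok();
    let mut attention = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        // The head-dim check only runs when a config was successfully loaded.
        if !matches!(config.head_dim, Some(64 | 128 | 256)) && prefix_caching.is_none() {
            prefix_caching = Some("0".to_string());
            attention = Some("flashdecoding".to_string());
        }
    }
    // With no config and no env overrides, fall back to defaults (assumed here)
    // instead of failing.
    (
        prefix_caching.unwrap_or_else(|| "true".to_string()),
        attention.unwrap_or_else(|| "flashinfer".to_string()),
    )
}

fn main() {
    println!("no config: {:?}", resolve_attention(&None));
    let cfg = Config { head_dim: Some(96) };
    println!("head_dim=96: {:?}", resolve_attention(&Some(cfg)));
}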
@@ -1502,68 +1504,68 @@ fn main() -> Result<(), LauncherError> {
     tracing::info!("{:#?}", args);
 
-    let get_max_positions_quantize =
-        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
-            let model_id = args.model_id.clone();
-            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-            let filename = if !path.exists() {
-                // Assume it's a hub id
-                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                    // env variable has precedence over on file token.
-                    ApiBuilder::new().with_token(Some(token)).build()?
-                } else {
-                    Api::new()?
-                };
-                let repo = if let Some(ref revision) = args.revision {
-                    api.repo(Repo::with_revision(
-                        model_id,
-                        RepoType::Model,
-                        revision.to_string(),
-                    ))
-                } else {
-                    api.model(model_id)
-                };
-                repo.get("config.json")?
-            } else {
-                path.push("config.json");
-                path
-            };
-
-            let content = std::fs::read_to_string(filename)?;
-            let config: RawConfig = serde_json::from_str(&content)?;
-
-            let config: Config = config.into();
-            let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
-            tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-            std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
-            std::env::set_var("ATTENTION", attention);
-            let quantize = config.quantize;
-            // Quantization usually means you're even more RAM constrained.
-            let max_default = 4096;
-            if let Some(max_position_embeddings) = config.max_position_embeddings {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    Ok((max_default, quantize))
-                } else {
-                    Ok((max_position_embeddings, quantize))
-                }
-            } else {
-                Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )))
-            }
-        };
-    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
-        get_max_positions_quantize().unwrap_or((4096, None));
+    let get_config = || -> Result<Config, Box<dyn std::error::Error>> {
+        let model_id = args.model_id.clone();
+        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+        let filename = if !path.exists() {
+            // Assume it's a hub id
+            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                // env variable has precedence over on file token.
+                ApiBuilder::new().with_token(Some(token)).build()?
+            } else {
+                Api::new()?
+            };
+            let repo = if let Some(ref revision) = args.revision {
+                api.repo(Repo::with_revision(
+                    model_id,
+                    RepoType::Model,
+                    revision.to_string(),
+                ))
+            } else {
+                api.model(model_id)
+            };
+            repo.get("config.json")?
+        } else {
+            path.push("config.json");
+            path
+        };
+
+        let content = std::fs::read_to_string(filename)?;
+        let config: RawConfig = serde_json::from_str(&content)?;
+        let config: Config = config.into();
+        Ok(config)
+    };
+    let config: Option<Config> = get_config().ok();
+    let quantize = config.as_ref().and_then(|c| c.quantize);
+    // Quantization usually means you're even more RAM constrained.
+    let max_default = 4096;
+    let max_position_embeddings = if let Some(config) = &config {
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                max_default
+            } else {
+                max_position_embeddings
+            }
+        } else {
+            max_default
+        }
+    } else {
+        max_default
+    };
+    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
+    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
+    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("ATTENTION", attention);
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
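The launcher-side change, reduced to a hedged sketch (illustrative load_config closure and Config fields, not the real launcher code): config loading is isolated in one fallible closure, .ok() turns any failure into None, and every consumer falls back to a default instead of aborting startup.

// Sketch of the "fallible loader + .ok() fallback" pattern used above.
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<String>,
}

fn main() {
    let load_config = || -> Result<Config, Box<dyn std::error::Error>> {
        // Any I/O or parse failure becomes an Err; parsing is elided in this sketch.
        let _content = std::fs::read_to_string("config.json")?;
        Ok(Config {
            max_position_embeddings: Some(8192),
            quantize: None,
        })
    };

    // Errors no longer abort startup: a missing or unreadable config yields None.
    let config: Option<Config> = load_config().ok();
    let quantize = config.as_ref().and_then(|c| c.quantize.clone());

    // Each consumer picks a safe default when the config is absent.
    let max_default = 4096;
    let max_position_embeddings = config
        .as_ref()
        .and_then(|c| c.max_position_embeddings)
        .map(|max| max.min(max_default)) // cap at the default, as the real code does
        .unwrap_or(max_default);

    println!("quantize={quantize:?} max_position_embeddings={max_position_embeddings}");
}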