Mirror of https://github.com/huggingface/text-generation-inference.git
Updating logic + non flash.

commit 6994fa12f8 (parent 10534511ea)
@@ -137,7 +137,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 
 #[derive(Deserialize)]
 struct RawConfig {
+    max_position_embeddings: Option<usize>,
+    n_positions: Option<usize>,
     model_type: Option<String>,
+    max_seq_len: Option<usize>,
     quantization_config: Option<QuantizationConfig>,
     n_embd: Option<usize>,
     hidden_size: Option<usize>,
@@ -157,6 +160,7 @@ struct VisionConfig {}
 
 #[derive(Deserialize)]
 struct Config {
+    max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
@@ -166,6 +170,10 @@ struct Config {
 
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
+        let max_position_embeddings = other
+            .max_position_embeddings
+            .or(other.max_seq_len)
+            .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
         let head_dim = other.head_dim.or_else(|| {
             match (other.hidden_size, other.n_embd, other.num_attention_heads) {
@@ -187,6 +195,7 @@ impl From<RawConfig> for Config {
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
+            max_position_embeddings,
             quantize,
             head_dim,
             model_type,
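
The new RawConfig fields exist because model config.json files name the context length differently depending on the architecture (max_position_embeddings, max_seq_len, or n_positions); the From impl above simply takes the first key that is present. A minimal standalone sketch of that fallback, with hypothetical values rather than a real model config:

// Illustrative only: mirrors the Option::or fallback chain used in
// `impl From<RawConfig> for Config` above.
fn resolve_context_length(
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
    n_positions: Option<usize>,
) -> Option<usize> {
    // The first populated key wins, in this priority order.
    max_position_embeddings.or(max_seq_len).or(n_positions)
}

fn main() {
    // GPT-2 style config that only exposes `n_positions` (hypothetical values).
    assert_eq!(resolve_context_length(None, None, Some(1024)), Some(1024));
    // A config exposing `max_position_embeddings` wins over the others.
    assert_eq!(
        resolve_context_length(Some(32768), Some(2048), None),
        Some(32768)
    );
}

Keeping the fallback inside the conversion means the rest of the launcher only ever reads Config::max_position_embeddings, whichever key the model actually used.
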
@@ -479,7 +488,7 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    /// Default to min(max_allocatable, max_position_embeddings)
+    /// Default to min(max_position_embeddings, 4096)
     #[clap(long, env)]
     max_total_tokens: Option<usize>,
 
@@ -1667,6 +1676,28 @@ fn main() -> Result<(), LauncherError> {
     let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
     // Quantization usually means you're even more RAM constrained.
+    let max_default = 4096;
+
+    let max_position_embeddings = if let Some(config) = &config {
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                max_default
+            } else {
+                max_position_embeddings
+            }
+        } else {
+            max_default
+        }
+    } else {
+        max_default
+    };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
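
Stripped of the CLI-flag checks and the log message, the defaulting rule the block above applies can be sketched as a small helper (illustrative only, not the launcher's actual function):

// Illustrative only: the value the launcher ends up using as its working
// default, given the context length advertised by the model config (if any)
// and assuming the user passed none of the related CLI flags.
fn effective_max_position_embeddings(config_value: Option<usize>) -> usize {
    const MAX_DEFAULT: usize = 4096;
    match config_value {
        // Long-context models are clamped to 4096 by default to save VRAM.
        Some(v) if v > MAX_DEFAULT => MAX_DEFAULT,
        // Shorter contexts are kept as-is.
        Some(v) => v,
        // No usable config: fall back to the default.
        None => MAX_DEFAULT,
    }
}

fn main() {
    assert_eq!(effective_max_position_embeddings(Some(131072)), 4096);
    assert_eq!(effective_max_position_embeddings(Some(2048)), 2048);
    assert_eq!(effective_max_position_embeddings(None), 4096);
}
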
@@ -1690,15 +1721,8 @@ fn main() -> Result<(), LauncherError> {
         match args.max_batch_prefill_tokens {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                // let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
-                //     max_batch_size * max_input_tokens
-                // } else {
-                //     // Adding some edge in order to account for potential block_size alignement
-                //     // issue.
-                //     max_input_tokens + 50
-                // } as u32;
                 // TODO figure out hardware optimal value
-                let value = 4096;
+                let value = 4096.min(max_position_embeddings as u32);
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                 value
             }
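
With that in place, the default for max_batch_prefill_tokens is no longer a flat 4096 but is also bounded by the resolved context length; since the earlier block already caps max_position_embeddings at 4096, the min mainly matters for models with shorter contexts. A hedged sketch of the defaulting:

// Illustrative only: the new default for `max_batch_prefill_tokens` when the
// flag is not set. `max_position_embeddings` is the already-resolved value
// from the block above (so it is at most 4096 here).
fn default_max_batch_prefill_tokens(max_position_embeddings: usize) -> u32 {
    4096u32.min(max_position_embeddings as u32)
}

fn main() {
    // A 2048-token model also lowers the prefill budget to 2048.
    assert_eq!(default_max_batch_prefill_tokens(2048), 2048);
    // Anything at or above 4096 keeps the 4096 cap.
    assert_eq!(default_max_batch_prefill_tokens(4096), 4096);
}
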
@@ -1412,7 +1412,7 @@ class FlashCausalLM(Model):
 
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
-            _, batch, _ = self.generate_token(batch)
+            _, _batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@@ -1442,14 +1442,14 @@ class FlashCausalLM(Model):
             max_total_tokens = num_blocks * BLOCK_SIZE
 
         else:
-            max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
+            max_total_tokens = sum(batch.cache_lengths)
         max_input_tokens = (
             max_total_tokens - 1
             if max_input_tokens is None
             else max_input_tokens
         )
 
-        del batch
+        del _batch, batch
         self.kv_cache = []
         empty_cache()
 
@@ -132,7 +132,13 @@ class Model(ABC):
         self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
     ) -> Tuple[Optional[int], int, int]:
         self.generate_token(batch)
-        return None, 0, 0
+        total = sum(len(i) for i in batch.input_ids)
+        if max_total_tokens is None:
+            max_total_tokens = total
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def decode_token(
         self,