Fixing legacy and CPU configs.

Nicolas Patry 2024-05-23 08:52:28 +00:00
parent 2f243a1a15
commit f48b6109fd
3 changed files with 85 additions and 45 deletions

View File

@@ -21,10 +21,38 @@ use tracing_subscriber::EnvFilter;
mod env_runtime;
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
n_positions: Option<usize>,
max_seq_len: Option<usize>,
}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
max_seq_len: Option<usize>,
}
impl From<RawConfig> for Config {
fn from(other: RawConfig) -> Self {
if other.max_position_embeddings.is_some() {
Config {
max_position_embeddings: other.max_position_embeddings,
}
} else if other.max_seq_len.is_some() {
Config {
max_position_embeddings: other.max_seq_len,
}
} else if other.n_positions.is_some() {
Config {
max_position_embeddings: other.n_positions,
}
} else {
Config {
max_position_embeddings: None,
}
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
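The RawConfig to Config conversion above normalizes the three field names models use for their context length: newer configs expose max_position_embeddings, some use max_seq_len, and legacy GPT-2 style configs only have n_positions. A minimal Python sketch of the same fallback order, purely for illustration (the helper name is made up; the launcher's real implementation is the Rust above):

# Illustrative sketch of the fallback order used above:
# max_position_embeddings, then max_seq_len, then n_positions.
# normalized_max_positions is a hypothetical helper, not part of TGI.
import json
from typing import Optional


def normalized_max_positions(raw_config: dict) -> Optional[int]:
    for key in ("max_position_embeddings", "max_seq_len", "n_positions"):
        value = raw_config.get(key)
        if value is not None:
            return value
    return None


# A legacy GPT-2 style config only exposes n_positions:
print(normalized_max_positions(json.loads('{"n_positions": 1024}')))  # -> 1024
# A config with none of the fields yields None, mirroring max_position_embeddings: None.
print(normalized_max_positions({}))  # -> None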
@@ -1309,13 +1337,13 @@ fn main() -> Result<(), LauncherError> {
};
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
@@ -1324,18 +1352,15 @@ fn main() -> Result<(), LauncherError> {
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
Ok(max_default)
} else {
max_position_embeddings
}
}
_ => {
return Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)));
}
};
Ok(max_position_embeddings)
}
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
}
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
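With the config normalized, the launcher's default logic becomes a single if-let: a model maximum above 4096 is capped at 4096 by default (logging a hint about the --max-* flags when none were passed), a smaller maximum is used as-is, and a missing maximum becomes an error that the final unwrap_or(4096) turns back into the 4096 default. A rough Python sketch of that decision, with an invented function name (not the real Rust closure):

# Hypothetical sketch of the default-picking logic above.
from typing import Optional

MAX_DEFAULT = 4096


def pick_default_max(model_max: Optional[int]) -> int:
    if model_max is None:
        # Mirrors get_max_position_embeddings().unwrap_or(4096): an unreadable
        # config or "no max defined" falls back to the 4096 default.
        return MAX_DEFAULT
    if model_max > MAX_DEFAULT:
        # Cap the default to save VRAM; the logged hint points users at the
        # --max-input-tokens / --max-total-tokens flags to raise it explicitly.
        return MAX_DEFAULT
    return model_max


print(pick_default_max(32768))  # -> 4096
print(pick_default_max(2048))   # -> 2048
print(pick_default_max(None))   # -> 4096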

View File

@@ -472,6 +472,7 @@ def get_model(
)
elif model_type == GPT2:
if FLASH_ATTENTION:
try:
return FlashGPT2(
model_id,
revision,
@@ -480,6 +481,17 @@
dtype=dtype,
trust_remote_code=trust_remote_code,
)
except RuntimeError as e:
# Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
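The model-loading change wraps the flash GPT-2 path in a try/except: legacy GPT-2 checkpoints ship weights under a variety of names, so when FlashGPT2 raises a RuntimeError the server now logs a warning and falls back to the plain CausalLM implementation instead of failing outright. A generic sketch of that try-the-fast-path-then-fall-back pattern, with placeholder classes standing in for the TGI models:

# Generic sketch of "try the optimized model, fall back to the compatible one".
# FastModel and PortableModel are placeholders, not real TGI classes.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class PortableModel:
    def __init__(self, model_id: str):
        self.model_id = model_id


class FastModel(PortableModel):
    def __init__(self, model_id: str):
        if "legacy" in model_id:
            # Stand-in for the RuntimeError raised on unexpected weight names.
            raise RuntimeError("unexpected weight names")
        super().__init__(model_id)


def load_model(model_id: str) -> PortableModel:
    try:
        return FastModel(model_id)
    except RuntimeError as e:
        logger.warning(f"Couldn't load fast variant: {e}")
        return PortableModel(model_id)


print(type(load_model("legacy-gpt2")).__name__)  # -> PortableModel
print(type(load_model("gpt2")).__name__)         # -> FastModel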

View File

@@ -14,13 +14,20 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
except (ImportError, NotImplementedError):
VLM_BATCH_TYPES = set()
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
from text_generation_server.models.globals import set_model_id
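Previously server.py imported the VLM batch classes unconditionally, which broke CPU-only installs where those models and their dependencies are unavailable. The imports now live in a try block that collects the classes into VLM_BATCH_TYPES and leaves the set empty on ImportError or NotImplementedError. A minimal sketch of that optional-import pattern, with an invented module name:

# Minimal sketch of the optional-import pattern above.
# `vision_extras` is an invented module name; on a machine without it,
# the except branch leaves the set empty.
try:
    from vision_extras import ImageBatch, VideoBatch  # hypothetical optional classes
    SPECIAL_BATCH_TYPES = {ImageBatch, VideoBatch}
except (ImportError, NotImplementedError):
    SPECIAL_BATCH_TYPES = set()


class TextBatch:
    """Always-available batch type, analogous to the default causal LM batch."""


# Membership checks degrade gracefully when the extras are missing:
print(TextBatch in SPECIAL_BATCH_TYPES)  # -> False either way
print(len(SPECIAL_BATCH_TYPES))          # -> 0 on a CPU-only style install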
@@ -96,11 +103,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
except ImportError:
pass
if self.model.batch_type in {
IdeficsCausalLMBatch,
VlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type in VLM_BATCH_TYPES
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,
self.model.tokenizer,
@@ -121,11 +126,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
start = time.time_ns()
if self.model.batch_type in {
IdeficsCausalLMBatch,
VlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type in VLM_BATCH_TYPES
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,
self.model.tokenizer,
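With the set in place, both updated handlers (Prefill is the one visible by name above) check self.model.batch_type in VLM_BATCH_TYPES instead of hard-coding the three classes, so the branch is simply skipped on builds where none of the VLM types could be imported. A small placeholder sketch of dispatching on that membership (the class names and the processor argument are illustrative, not the TGI API):

# Placeholder sketch of dispatching on batch-type membership.
class TextBatch:
    @classmethod
    def from_pb(cls, pb):
        return cls()


class VisionBatch(TextBatch):
    @classmethod
    def from_pb_processor(cls, pb, processor):
        return cls()


VLM_BATCH_TYPES = {VisionBatch}


def build_batch(batch_type, pb, processor=None):
    if batch_type in VLM_BATCH_TYPES:
        # VLM batches need the extra processor, hence the separate constructor.
        return batch_type.from_pb_processor(pb, processor)
    return batch_type.from_pb(pb)


print(type(build_batch(VisionBatch, pb=None, processor=object())).__name__)  # -> VisionBatch
print(type(build_batch(TextBatch, pb=None)).__name__)                        # -> TextBatch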