diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d6b45c1d..b80e0230 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -23,10 +23,28 @@ use tracing_subscriber::EnvFilter; mod env_runtime; +#[derive(Deserialize)] +struct RawConfig { + max_position_embeddings: Option, + n_positions: Option, + max_seq_len: Option, +} + #[derive(Deserialize)] struct Config { max_position_embeddings: Option, - max_seq_len: Option, +} + +impl From for Config { + fn from(other: RawConfig) -> Self { + let max_position_embeddings = other + .max_position_embeddings + .or(other.max_seq_len) + .or(other.n_positions); + Config { + max_position_embeddings, + } + } } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -1324,33 +1342,30 @@ fn main() -> Result<(), LauncherError> { }; let content = std::fs::read_to_string(filename)?; - let config: Config = serde_json::from_str(&content)?; + let config: RawConfig = serde_json::from_str(&content)?; + let config: Config = config.into(); // Quantization usually means you're even more RAM constrained. let max_default = 4096; - let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { - (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - max_default - } else { - max_position_embeddings + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); } + Ok(max_default) + } else { + Ok(max_position_embeddings) } - _ => { - return Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))); - } - }; - Ok(max_position_embeddings) + } else { + Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))) + } }; let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 97828ffb..5184731f 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -21,6 +21,18 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.globals import set_model_id +try: + from text_generation_server.models.pali_gemma import PaliGemmaBatch + from text_generation_server.models.vlm_causal_lm import ( + VlmCausalLMBatch, + ) + from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + + VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch} +except (ImportError, NotImplementedError): + # These imports can fail on CPU/Non flash. + VLM_BATCH_TYPES = set() + class SignalHandler: KEEP_PROCESSING = True @@ -91,9 +103,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): start = time.time_ns() - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device - ) + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + generations, next_batch, timings = self.model.generate_token([batch]) self.cache.set(next_batch)