From f2d8c2e76fe3ab9e22e5f997953d695618a575f0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 22 Apr 2024 09:54:51 +0000 Subject: [PATCH] Fixing features for llava_next. Still issues with warmup and truncation atm. --- integration-tests/conftest.py | 8 ++++ integration-tests/models/test_llava_next.py | 1 + router/client/src/client.rs | 2 +- router/src/config.rs | 44 +++++++++++++++---- .../models/custom_modeling/llava_next.py | 4 ++ .../models/vlm_causal_lm.py | 36 +++++++++++---- 6 files changed, 77 insertions(+), 18 deletions(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index cf0f498d..ae3f977b 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -293,6 +293,7 @@ def launcher(event_loop): dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, + max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, ): port = random.randint(8000, 10_000) @@ -334,6 +335,9 @@ def launcher(event_loop): if max_input_length: args.append("--max-input-length") args.append(str(max_input_length)) + if max_batch_prefill_tokens: + args.append("--max-batch-prefill-tokens") + args.append(str(max_batch_prefill_tokens)) if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) @@ -371,6 +375,7 @@ def launcher(event_loop): dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, + max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, ): port = random.randint(8000, 10_000) @@ -395,6 +400,9 @@ def launcher(event_loop): if max_input_length: args.append("--max-input-length") args.append(str(max_input_length)) + if max_batch_prefill_tokens: + args.append("--max-batch-prefill-tokens") + args.append(str(max_batch_prefill_tokens)) if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index f5b290b1..b407d4aa 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -15,6 +15,7 @@ def flash_llava_next_handle(launcher): "llava-hf/llava-v1.6-mistral-7b-hf", num_shard=4, max_input_length=4000, + max_batch_prefill_tokens=8000, max_total_tokens=4096, ) as handle: yield handle diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 545cddd0..24ecd2ad 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -114,8 +114,8 @@ impl Client { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let mut inputs = String::new(); - inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); requests.push(Request { id: 0, diff --git a/router/src/config.rs b/router/src/config.rs index 8c9fa33f..b050a4d9 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -57,20 +57,46 @@ fn select_best_resolution( best_fit.unwrap_or((original_height, original_width)) } +fn get_unpadded_features( + height: usize, + width: usize, + npatches: usize, + num_patch_height: usize, + num_patch_width: usize, +) -> (usize, usize) { + let current_height = npatches * num_patch_height; + let current_width = npatches * num_patch_width; + + let aspect_ratio: f64 = width as f64 / height as f64; + let current_aspect_ratio: f64 = current_width as f64 / current_height as f64; + let (current_height, current_width) = if aspect_ratio > current_aspect_ratio { + let new_height = (height * current_width) / width; + // let padding = (current_height - new_height) / 2; + // let new_height = current_height - 2 * padding; + (new_height, current_width) + } else { + let new_width = (width * current_height) / height; + (current_height, new_width) + }; + println!("{current_height} {current_width}"); + + let unpadded_features = current_height * current_width; + let newline_features = current_height; + (unpadded_features, newline_features) +} + impl LlavaNext { pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { let image_size = self.vision_config.image_size; let patch_size = self.vision_config.patch_size; assert!(image_size % patch_size == 0); let npatches = image_size / patch_size; + println!("{npatches} {image_size} {patch_size}"); let (num_patch_height, num_patch_width) = get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); - // Ceil - // TODO Very odd artifact when the rounding is super close - let height_of_patch = (height * npatches + width - 10) / width; - let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width; - // They are only added after width - let newline_features = height_of_patch * num_patch_width; + + let (unpadded_features, newline_features) = + get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width); // The base patch covers the entire image let base_features = npatches.pow(2); unpadded_features + newline_features + base_features @@ -159,14 +185,16 @@ mod test { ], }; + let slots = config.get_number_of_features(20, 20); + assert_eq!(slots, 1176); let slots = config.get_number_of_features(640, 640); assert_eq!(slots, 2928); let slots = config.get_number_of_features(480, 640); assert_eq!(slots, 2340); let slots = config.get_number_of_features(899, 1024); - assert_eq!(slots, 2732); + assert_eq!(slots, 2634); let slots = config.get_number_of_features(1024, 899); - assert_eq!(slots, 3320); + assert_eq!(slots, 2640); let slots = config.get_number_of_features(1067, 1600); assert_eq!(slots, 2144); } diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 14bf19e1..0f1944a7 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -229,6 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module): self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) + if image_sizes[image_idx][0].item() == 22: + import ipdb + + ipdb.set_trace() image_feature = image_feature.view( num_patch_height, num_patch_width, height, width, -1 ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1e60ab1f..5394feb5 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -64,7 +64,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def image_text_replacement(image_input, config) -> str: +def image_text_replacement(image_input, config, image_id) -> str: if config.model_type == "idefics2": # TODO technically depends on image splitting which is not implemented. num_features = 320 @@ -74,7 +74,7 @@ def image_text_replacement(image_input, config) -> str: + "" ) elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][0] + height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) from loguru import logger @@ -84,6 +84,26 @@ def image_text_replacement(image_input, config) -> str: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") +def get_unpadded_features( + height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int +) -> Tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = width / height + current_aspect_ratio: float = current_width / current_height + if aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + current_height = new_height + else: + new_width = (width * current_height) // height + current_width = new_width + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + def get_number_of_features(height: int, width: int, config) -> int: # From config # Hardcoded for CLIP for now @@ -101,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int: image_grid_pinpoints, image_size, ) - - height_of_patch = (height * npatches + width - 10) // width - - unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width - # They are only added after width - newline_features = height_of_patch * num_patch_width + unpadded_features, newline_features = get_unpadded_features( + height, width, npatches, num_patch_height, num_patch_width + ) # The base patch covers the entire image base_features = npatches**2 return unpadded_features + newline_features + base_features @@ -149,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch): for r in requests: chunks = split(r.inputs) full_text = "" + image_id = 0 for chunk in chunks: if chunk["type"] == "text": full_text += chunk["content"] @@ -166,7 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch): "Cannot process input image not starting with data:" ) image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement(image_input, config) + full_text += image_text_replacement(image_input, config, image_id) image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk['type']}")