Fixing features for llava_next. Still issues with warmup and truncation atm.
Nicolas Patry 2024-04-22 09:54:51 +00:00
parent ae2b4e1c23
commit f2d8c2e76f
6 changed files with 77 additions and 18 deletions

View File

@@ -293,6 +293,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -334,6 +335,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_batch_prefill_tokens:
+            args.append("--max-batch-prefill-tokens")
+            args.append(str(max_batch_prefill_tokens))
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
@@ -371,6 +375,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -395,6 +400,9 @@ def launcher(event_loop):
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
+       if max_batch_prefill_tokens:
+           args.append("--max-batch-prefill-tokens")
+           args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))

View File

@@ -15,6 +15,7 @@ def flash_llava_next_handle(launcher):
         "llava-hf/llava-v1.6-mistral-7b-hf",
         num_shard=4,
         max_input_length=4000,
+        max_batch_prefill_tokens=8000,
         max_total_tokens=4096,
     ) as handle:
         yield handle
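For context, a test module would normally consume this handle through a second fixture that waits for the launched server to become healthy before handing out the client. A minimal sketch, assuming the health()/client helpers used by the other integration-test fixtures in this repo:

import pytest


@pytest.fixture(scope="module")
async def flash_llava_next(flash_llava_next_handle):
    # Wait (here up to 300s, an arbitrary timeout) for the launched shards to
    # come up, then expose the async client to the actual tests.
    await flash_llava_next_handle.health(300)
    return flash_llava_next_handle.client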

View File

@@ -114,8 +114,8 @@ impl Client {
         let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
         let mut inputs = String::new();
+        inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
         inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-        inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
         requests.push(Request {
             id: 0,
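The warmup prompt now leads with a tiny 20x20 base64 PNG so that the prefill warmup also exercises the image path. A rough sketch of why the llava_next fixture above raises max_batch_prefill_tokens to 8000: the 1176 figure comes from the new get_number_of_features(20, 20) assertion further down, treating the repeated "_test " text as roughly one token per repetition is an approximation, and per the commit message warmup/truncation are not fully settled yet.

# Rough prefill accounting for the warmup prompt. Assumption: the 20x20
# placeholder PNG expands to 1176 image-feature tokens (per the new config.rs
# test) and the truncated text part contributes about max_input_length tokens.
max_input_length = 4000
image_feature_tokens = 1176      # get_number_of_features(20, 20)
text_tokens = max_input_length   # "_test " repeated, then truncated
approx_prefill_tokens = text_tokens + image_feature_tokens
print(approx_prefill_tokens)     # ~5176, above the fixture's max_total_tokens=4096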

View File

@@ -57,20 +57,46 @@ fn select_best_resolution(
     best_fit.unwrap_or((original_height, original_width))
 }
+
+fn get_unpadded_features(
+    height: usize,
+    width: usize,
+    npatches: usize,
+    num_patch_height: usize,
+    num_patch_width: usize,
+) -> (usize, usize) {
+    let current_height = npatches * num_patch_height;
+    let current_width = npatches * num_patch_width;
+    let aspect_ratio: f64 = width as f64 / height as f64;
+    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
+    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
+        let new_height = (height * current_width) / width;
+        // let padding = (current_height - new_height) / 2;
+        // let new_height = current_height - 2 * padding;
+        (new_height, current_width)
+    } else {
+        let new_width = (width * current_height) / height;
+        (current_height, new_width)
+    };
+    println!("{current_height} {current_width}");
+    let unpadded_features = current_height * current_width;
+    let newline_features = current_height;
+    (unpadded_features, newline_features)
+}
+
 impl LlavaNext {
     pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
         let image_size = self.vision_config.image_size;
         let patch_size = self.vision_config.patch_size;
         assert!(image_size % patch_size == 0);
         let npatches = image_size / patch_size;
+        println!("{npatches} {image_size} {patch_size}");
         let (num_patch_height, num_patch_width) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
-        // Ceil
-        // TODO Very odd artifact when the rounding is super close
-        let height_of_patch = (height * npatches + width - 10) / width;
-        let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
-        // They are only added after width
-        let newline_features = height_of_patch * num_patch_width;
+        let (unpadded_features, newline_features) =
+            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
         // The base patch covers the entire image
         let base_features = npatches.pow(2);
         unpadded_features + newline_features + base_features
@@ -159,14 +185,16 @@ mod test {
             ],
         };
+        let slots = config.get_number_of_features(20, 20);
+        assert_eq!(slots, 1176);
         let slots = config.get_number_of_features(640, 640);
         assert_eq!(slots, 2928);
         let slots = config.get_number_of_features(480, 640);
         assert_eq!(slots, 2340);
         let slots = config.get_number_of_features(899, 1024);
-        assert_eq!(slots, 2732);
+        assert_eq!(slots, 2634);
         let slots = config.get_number_of_features(1024, 899);
-        assert_eq!(slots, 3320);
+        assert_eq!(slots, 2640);
         let slots = config.get_number_of_features(1067, 1600);
         assert_eq!(slots, 2144);
     }
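To see where the corrected expectations come from, here is the same arithmetic in Python for the 899x1024 case. The vision config values (image_size=336, patch_size=14, hence 24 patches per side) and the (672, 672) best-fit resolution are assumptions based on the llava-v1.6 defaults, since the test's config literal is elided above:

# Worked example for get_number_of_features(899, 1024), assuming image_size=336,
# patch_size=14 and a (672, 672) best-fit pinpoint, i.e. a 2x2 grid of tiles.
npatches = 336 // 14                          # 24 patches per tile side
num_patch_height, num_patch_width = 2, 2      # from get_anyres_image_grid_shape

height, width = 899, 1024
current_height = npatches * num_patch_height  # 48
current_width = npatches * num_patch_width    # 48

# The image is wider (1024/899 ~ 1.14) than the square patch grid, so the grid
# height is shrunk to preserve the image's aspect ratio.
current_height = (height * current_width) // width    # 42

unpadded_features = current_height * current_width    # 42 * 48 = 2016
newline_features = current_height                     # one newline slot per row
base_features = npatches**2                           # 576, the full-image base tile
print(unpadded_features + newline_features + base_features)  # 2634

# The previous heuristic, (height * npatches + width - 10) // width = 22, gave
# 24 * 22 * 2 * 2 + 22 * 2 + 576 = 2732, which is the old expected value above.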

View File

@@ -229,6 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,
                 )
+                if image_sizes[image_idx][0].item() == 22:
+                    import ipdb
+                    ipdb.set_trace()
                 image_feature = image_feature.view(
                     num_patch_height, num_patch_width, height, width, -1
                 )

View File

@@ -64,7 +64,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size
-def image_text_replacement(image_input, config) -> str:
+def image_text_replacement(image_input, config, image_id) -> str:
     if config.model_type == "idefics2":
         # TODO technically depends on image splitting which is not implemented.
         num_features = 320
@@ -74,7 +74,7 @@ def image_text_replacement(image_input, config) -> str:
             + "<fake_token_around_image>"
         )
     elif config.model_type == "llava_next":
-        height, width = image_input["image_sizes"][0]
+        height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
         from loguru import logger
@@ -84,6 +84,26 @@ def image_text_replacement(image_input, config) -> str:
     raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+def get_unpadded_features(
+    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now
@@ -101,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = (height * npatches + width - 10) // width
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
@@ -149,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
@@ -166,7 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                         "Cannot process input image not starting with data:"
                     )
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config)
+                    full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")