Fixing features for llava_next. Still issues with warmup and truncation

atm.
This commit is contained in:
Nicolas Patry 2024-04-22 09:54:51 +00:00
parent ae2b4e1c23
commit f2d8c2e76f
6 changed files with 77 additions and 18 deletions

View File

@ -293,6 +293,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
@ -334,6 +335,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
@ -371,6 +375,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
@ -395,6 +400,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))

View File

@ -15,6 +15,7 @@ def flash_llava_next_handle(launcher):
"llava-hf/llava-v1.6-mistral-7b-hf",
num_shard=4,
max_input_length=4000,
max_batch_prefill_tokens=8000,
max_total_tokens=4096,
) as handle:
yield handle

View File

@ -114,8 +114,8 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
requests.push(Request {
id: 0,

View File

@ -57,20 +57,46 @@ fn select_best_resolution(
best_fit.unwrap_or((original_height, original_width))
}
fn get_unpadded_features(
height: usize,
width: usize,
npatches: usize,
num_patch_height: usize,
num_patch_width: usize,
) -> (usize, usize) {
let current_height = npatches * num_patch_height;
let current_width = npatches * num_patch_width;
let aspect_ratio: f64 = width as f64 / height as f64;
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
// let padding = (current_height - new_height) / 2;
// let new_height = current_height - 2 * padding;
(new_height, current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
};
println!("{current_height} {current_width}");
let unpadded_features = current_height * current_width;
let newline_features = current_height;
(unpadded_features, newline_features)
}
impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size;
let patch_size = self.vision_config.patch_size;
assert!(image_size % patch_size == 0);
let npatches = image_size / patch_size;
println!("{npatches} {image_size} {patch_size}");
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
// TODO Very odd artifact when the rounding is super close
let height_of_patch = (height * npatches + width - 10) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
let (unpadded_features, newline_features) =
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
// The base patch covers the entire image
let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features
@ -159,14 +185,16 @@ mod test {
],
};
let slots = config.get_number_of_features(20, 20);
assert_eq!(slots, 1176);
let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732);
assert_eq!(slots, 2634);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
assert_eq!(slots, 2640);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
}

View File

@ -229,6 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if image_sizes[image_idx][0].item() == 22:
import ipdb
ipdb.set_trace()
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)

View File

@ -64,7 +64,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(image_input, config) -> str:
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
@ -74,7 +74,7 @@ def image_text_replacement(image_input, config) -> str:
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][0]
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
@ -84,6 +84,26 @@ def image_text_replacement(image_input, config) -> str:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
@ -101,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints,
image_size,
)
height_of_patch = (height * npatches + width - 10) // width
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
@ -149,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
@ -166,7 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config)
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")