mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fix LLaVA-NeXT handling of non-square images
We could get shape mismatches with non-square images, resulting in an exception that crashed the backend. When post-processing an image, features corresponding to padding are removed when padding was needed. This is also reflected in the calculation of the number of image tokens to get the correct number of slots. However, there was a mismatch between the post-processing and the slot calculation. The image post-processing could exclude fewer padding features due to rounding. This change updates the image token calculation to correspond to the image postprocessing. Fixes #1777. While investigating this, I found another issue where the upstream code contains a bug that swaps the height and width dimensions after computing the image grid shape. Since the models were also trained with this bug, we should reproduce the same bug to ensure that we are generating the same features.
This commit is contained in:
parent
9ce4552bae
commit
e7b1d5e422
@ -71,10 +71,12 @@ fn get_unpadded_features(
|
||||
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||
let new_height = (height * current_width) / width;
|
||||
(new_height, current_width)
|
||||
let padding = (current_height - new_height) / 2;
|
||||
(current_height - (2 * padding), current_width)
|
||||
} else {
|
||||
let new_width = (width * current_height) / height;
|
||||
(current_height, new_width)
|
||||
let padding = (current_width - new_width) / 2;
|
||||
(current_height, current_width - (2 * padding))
|
||||
};
|
||||
|
||||
let unpadded_features = current_height * current_width;
|
||||
@ -88,7 +90,9 @@ impl LlavaNext {
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
// Dimensions are intentionally swapped to be bug-compatible with
|
||||
// upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
let (num_patch_width, num_patch_height) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
|
||||
let (unpadded_features, newline_features) =
|
||||
|
@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
@ -229,7 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
|
@ -22,7 +22,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -64,19 +64,26 @@ def image_text_replacement(processor, image_input, config, image_id) -> str:
|
||||
|
||||
|
||||
def get_unpadded_features(
|
||||
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
aspect_ratio: float = width / height
|
||||
aspect_ratio: float = original_width / original_height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
current_height = new_height
|
||||
new_height = (original_height * current_width) // original_width
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height = current_height - (2 * padding)
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
current_width = new_width
|
||||
new_width = (original_width * current_height) // original_height
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width = current_width - (2 * padding)
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
@ -95,7 +102,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
[height, width],
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
|
Loading…
Reference in New Issue
Block a user