Making image passes run one image at a time to save VRAM.

Nicolas Patry 2024-04-22 20:40:21 +00:00
parent 60d2757c36
commit 1aa812da43
3 changed files with 65 additions and 48 deletions
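
The change touches three files: the warmup client now attaches the test image to a single request instead of every request, the Idefics2 forward pass encodes images one batch element at a time instead of all at once, and LlavaNext's embedding merge now fails with an actionable error instead of a bare shape mismatch.

The core idea: rather than pushing the whole (batch_size, num_images, channels, height, width) pixel tensor through the vision encoder in one call, loop over batch elements, encode each element's images separately, and stack the per-element hidden states at the end. Peak activation memory then scales with one batch element instead of the full batch. A minimal standalone sketch of the pattern (illustrative only, not the TGI code; encode_per_image and the stand-in encoder are invented for the example):

import torch
from torch import nn


def encode_per_image(vision_model: nn.Module, pixel_values: torch.Tensor) -> torch.Tensor:
    # pixel_values: (batch_size, num_images, C, H, W). Encoding one batch
    # element at a time keeps only that element's activations alive, so peak
    # VRAM is bounded by a single element instead of the whole batch.
    all_states = []
    for i in range(pixel_values.size(0)):
        states = vision_model(pixel_values[i])  # (num_images, ...) for one element
        all_states.append(states)
    return torch.stack(all_states, dim=0)


# Toy usage with a trivial stand-in "encoder".
encoder = nn.Flatten(start_dim=2)
batch = torch.randn(4, 2, 3, 14, 14)
print(encode_per_image(encoder, batch).shape)  # torch.Size([4, 2, 3, 196])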

@@ -115,7 +115,11 @@ impl Client {
         let mut inputs = String::new();
         inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-        inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
+        if n_tokens == 0 {
+            // 1 request is enough to test vision heads.
+            // Sending images on other queries messes up easily with truncation.
+            inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
+        }
         requests.push(Request {
             id: 0,

@@ -745,58 +745,66 @@ class Idefics2ForConditionalGeneration(nn.Module):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:
             batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
-            pixel_values = pixel_values.view(
-                batch_size * num_images, *pixel_values.shape[2:]
-            )
+            all_states = []
+            all_pixel_values = pixel_values
+            all_pixel_mask = pixel_attention_mask
+            for i in range(batch_size):
+                pixel_values = all_pixel_values.to(
+                    dtype=self.dtype
+                )  # fp16 compatibility
+                pixel_values = pixel_values[i : i + 1]
+                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
 
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(
-                dim=(-1, -2, -3)
-            ) != nb_values_per_image
-            pixel_values = pixel_values[real_images_inds].contiguous()
+                # Remove padding images - padding images are full 0.
+                nb_values_per_image = pixel_values.shape[1:].numel()
+                real_images_inds = (pixel_values == 0.0).sum(
+                    dim=(-1, -2, -3)
+                ) != nb_values_per_image
+                pixel_values = pixel_values[real_images_inds].contiguous()
 
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=(
-                        pixel_values.size(0),
-                        pixel_values.size(2),
-                        pixel_values.size(3),
-                    ),
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[
-                    real_images_inds
-                ].contiguous()
+                # Handle the vision attention mask
+                if pixel_attention_mask is None:
+                    pixel_attention_mask = torch.ones(
+                        size=(
+                            pixel_values.size(0),
+                            pixel_values.size(2),
+                            pixel_values.size(3),
+                        ),
+                        dtype=torch.bool,
+                        device=pixel_values.device,
+                    )
+                else:
+                    # Remove padding images from the mask
+                    pixel_attention_mask = all_pixel_mask[i : i + 1]
+                    pixel_attention_mask = pixel_attention_mask.view(
+                        1 * num_images, *pixel_attention_mask.shape[2:]
+                    )
+                    pixel_attention_mask = pixel_attention_mask[
+                        real_images_inds
+                    ].contiguous()
 
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(
-                dimension=1, size=patch_size, step=patch_size
-            )
-            patches_subgrid = patches_subgrid.unfold(
-                dimension=2, size=patch_size, step=patch_size
-            )
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+                patch_size = self.config.vision_config.patch_size
+                patches_subgrid = pixel_attention_mask.unfold(
+                    dimension=1, size=patch_size, step=patch_size
+                )
+                patches_subgrid = patches_subgrid.unfold(
+                    dimension=2, size=patch_size, step=patch_size
+                )
+                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
 
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            )
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values,
+                    patch_attention_mask=patch_attention_mask,
+                )
 
-            # Modality projection & resampling
-            image_hidden_states = self.connector(
-                image_hidden_states,
-                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
-            )
+                # Modality projection & resampling
+                image_hidden_states = self.connector(
+                    image_hidden_states,
+                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+                )
+                all_states.append(image_hidden_states)
+            image_hidden_states = torch.stack(all_states, dim=0)
 
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
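
The per-element loop keeps the existing padding-image filtering and patch-mask derivation intact; both steps can be exercised in isolation. A self-contained sketch with toy shapes (28x28 images and patch size 14, chosen only for the example):

import torch

# Toy shapes for illustration: 3 image slots, 2 of which are all-zero padding.
num_images, channels, height, width = 3, 3, 28, 28
patch_size = 14

pixel_values = torch.zeros(num_images, channels, height, width)
pixel_values[0] = torch.randn(channels, height, width)  # the only real image

# Padding images are full 0, so an image is real iff not every value is zero.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
print(pixel_values.shape)  # torch.Size([1, 3, 28, 28])

# With no mask provided, every pixel of the surviving images is attendable.
pixel_attention_mask = torch.ones(
    pixel_values.size(0), pixel_values.size(2), pixel_values.size(3), dtype=torch.bool
)

# Tile the pixel mask into patch_size x patch_size cells: a patch takes part
# in attention if at least one pixel in its cell is valid.
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask.shape)  # torch.Size([1, 2, 2])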

@@ -154,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
         """In place merges in vision_embeddings with inputs_embeds."""
         mask = input_ids == self.config.image_token_index
         # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        try:
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        except Exception as e:
+            raise RuntimeError(
+                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+            )
         return inputs_embeds
 
     def forward(
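
For context on the try/except above: the masked, in-place assignment only succeeds when the number of image-token positions in input_ids matches the number of image-feature rows, which is exactly what warmup is validating. A toy illustration (the token id 32000 and the shapes are arbitrary, not LlavaNext's actual configuration):

import torch

image_token_index = 32000  # arbitrary id for the example
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.randn(1, 2, 8)  # 2 feature rows for the 2 image tokens

mask = input_ids == image_token_index
# Succeeds: the 2 masked positions receive the 2 flattened feature rows.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

# With a mismatch (e.g. 3 image tokens but still 2 feature rows) the same
# assignment raises, which the model now surfaces as the actionable
# `--max-input-tokens` message instead of a bare shape error.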