Mirror of https://github.com/huggingface/text-generation-inference.git
Making image passes run image per image to save VRAM.
commit 1aa812da43
parent 60d2757c36
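The core of the change is in the Idefics2 forward pass: instead of pushing the whole batch of images through the vision encoder and connector at once, the new code loops over the batch, encodes one sample's images at a time, and stacks the per-sample hidden states afterwards, so peak activation memory is bounded by a single sample's images rather than the full batch. The sketch below illustrates only that looping pattern; `encode_images` and all shapes are placeholders rather than the model's actual API, and the real code additionally drops all-zero padding images and builds patch attention masks per sample.

import torch


def encode_images(pixel_values: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for vision_model(...) followed by connector(...):
    # maps (num_images, C, H, W) -> (num_images, seq_len, hidden_size).
    num_images = pixel_values.shape[0]
    return torch.zeros(num_images, 64, 1024)


def image_pass_per_sample(all_pixel_values: torch.Tensor) -> torch.Tensor:
    # all_pixel_values: (batch_size, num_images, C, H, W)
    batch_size = all_pixel_values.shape[0]
    all_states = []
    for i in range(batch_size):
        # Only one sample's images live in the encoder at a time,
        # which caps peak VRAM at the cost of a Python-level loop.
        sample = all_pixel_values[i]
        all_states.append(encode_images(sample))
    # Reassemble the batch; torch.stack assumes equal per-sample shapes.
    return torch.stack(all_states, dim=0)

The trade-off is a sequential loop over the batch, so multi-image batches may run slightly slower in exchange for the lower memory ceiling.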
@@ -115,7 +115,11 @@ impl Client {
             let mut inputs = String::new();
             inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-            inputs.push_str("");
+            if n_tokens == 0 {
+                // 1 request is enough to test vision heads.
+                // Sending images on other queries messes up easily with truncation.
+                inputs.push_str("");
+            }
 
             requests.push(Request {
                 id: 0,
@@ -745,58 +745,66 @@ class Idefics2ForConditionalGeneration(nn.Module):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:
             batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
-            pixel_values = pixel_values.view(
-                batch_size * num_images, *pixel_values.shape[2:]
-            )
+            all_states = []
+            all_pixel_values = pixel_values
+            all_pixel_mask = pixel_attention_mask
+            for i in range(batch_size):
+                pixel_values = all_pixel_values.to(
+                    dtype=self.dtype
+                )  # fp16 compatibility
+                pixel_values = pixel_values[i : i + 1]
+                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
 
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(
-                dim=(-1, -2, -3)
-            ) != nb_values_per_image
-            pixel_values = pixel_values[real_images_inds].contiguous()
+                # Remove padding images - padding images are full 0.
+                nb_values_per_image = pixel_values.shape[1:].numel()
+                real_images_inds = (pixel_values == 0.0).sum(
+                    dim=(-1, -2, -3)
+                ) != nb_values_per_image
+                pixel_values = pixel_values[real_images_inds].contiguous()
 
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=(
-                        pixel_values.size(0),
-                        pixel_values.size(2),
-                        pixel_values.size(3),
-                    ),
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[
-                    real_images_inds
-                ].contiguous()
+                # Handle the vision attention mask
+                if pixel_attention_mask is None:
+                    pixel_attention_mask = torch.ones(
+                        size=(
+                            pixel_values.size(0),
+                            pixel_values.size(2),
+                            pixel_values.size(3),
+                        ),
+                        dtype=torch.bool,
+                        device=pixel_values.device,
+                    )
+                else:
+                    # Remove padding images from the mask
+                    pixel_attention_mask = all_pixel_mask[i : i + 1]
+                    pixel_attention_mask = pixel_attention_mask.view(
+                        1 * num_images, *pixel_attention_mask.shape[2:]
+                    )
+                    pixel_attention_mask = pixel_attention_mask[
+                        real_images_inds
+                    ].contiguous()
 
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(
-                dimension=1, size=patch_size, step=patch_size
-            )
-            patches_subgrid = patches_subgrid.unfold(
-                dimension=2, size=patch_size, step=patch_size
-            )
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+                patch_size = self.config.vision_config.patch_size
+                patches_subgrid = pixel_attention_mask.unfold(
+                    dimension=1, size=patch_size, step=patch_size
+                )
+                patches_subgrid = patches_subgrid.unfold(
+                    dimension=2, size=patch_size, step=patch_size
+                )
+                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
 
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            )
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values,
+                    patch_attention_mask=patch_attention_mask,
+                )
 
-            # Modality projection & resampling
-            image_hidden_states = self.connector(
-                image_hidden_states,
-                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
-            )
+                # Modality projection & resampling
+                image_hidden_states = self.connector(
+                    image_hidden_states,
+                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+                )
+                all_states.append(image_hidden_states)
+            image_hidden_states = torch.stack(all_states, dim=0)
 
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
@@ -154,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
         """In place merges in vision_embeddings with inputs_embeds."""
         mask = input_ids == self.config.image_token_index
         # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        try:
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        except Exception as e:
+            raise RuntimeError(
+                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+            )
         return inputs_embeds
 
     def forward(