feat: improve llava next pooling for granite vision

drbh 2025-06-04 13:50:39 +00:00
parent 1ff9d185d5
commit 30bdf922bd
4 changed files with 42 additions and 14 deletions

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
"""PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
@@ -115,12 +115,27 @@ class LlavaNextForConditionalGeneration(nn.Module):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
vision_feature_layer = []
# If vision_feature_layer is an int, treat it as a single layer index (possibly negative)
if isinstance(config.vision_feature_layer, int):
# A negative index counts back from the last layer; map it to an
# absolute index into the per-layer encoder outputs instead of trimming num_hidden_layers
if config.vision_feature_layer < 0:
# vision_config.num_hidden_layers += config.vision_feature_layer + 1
num = vision_config.num_hidden_layers + config.vision_feature_layer + 1
vision_feature_layer = [num]
else:
# vision_config.num_hidden_layers = config.vision_feature_layer + 1
vision_feature_layer = [config.vision_feature_layer + 1]
elif isinstance(config.vision_feature_layer, list):
# If the vision_feature_layer is a list, we assume it is a list of layer indices
# and we select the hidden states at those layers
vision_feature_layer = config.vision_feature_layer
self.vision_feature_layer = vision_feature_layer
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
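For reference, the int-or-list normalization above can be read on its own. Below is a minimal standalone sketch that mirrors the mapping in this hunk; the helper name and arguments are illustrative, not part of the server code:

```python
from typing import List, Union

def normalize_vision_feature_layer(
    feature_layer: Union[int, List[int]], num_hidden_layers: int
) -> List[int]:
    # Mirrors the constructor logic above: a negative int counts back from the
    # last layer, a non-negative int is used directly, and a list is kept as-is.
    if isinstance(feature_layer, int):
        if feature_layer < 0:
            return [num_hidden_layers + feature_layer + 1]
        return [feature_layer + 1]
    return list(feature_layer)

# e.g. a single negative index vs. an explicit multi-layer selection
print(normalize_vision_feature_layer(-2, 27))          # [26]
print(normalize_vision_feature_layer([3, 7, 15], 27))  # [3, 7, 15]
```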
@@ -194,6 +209,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
# vision_feature_layer is a list of layer indices; select the hidden states at those layers
hs_pool = [
image_features.hidden_states[layer_idx]
for layer_idx in self.vision_feature_layer
]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
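The selected layers are concatenated along the hidden dimension before the multi-modal projector, so the projector's input width scales with the number of selected layers; with a single layer this reduces to the previous behaviour of projecting one hidden state. A self-contained sketch of that pooling step, with made-up shapes that are not taken from any real config:

```python
import torch

# Pretend per-layer encoder outputs: 4 layers, 10 patch tokens, hidden size 8.
hidden_states = [torch.randn(1, 10, 8) for _ in range(4)]
vision_feature_layer = [0, 2, 3]  # hypothetical selected layer indices

hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
print(selected_image_feature.shape)  # torch.Size([1, 10, 24]) == 3 layers * hidden 8
```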

View File

@@ -358,6 +358,8 @@ class SiglipEncoder(nn.Module):
for i in range(config.num_hidden_layers)
]
)
# Pre-allocate reusable list to avoid memory allocation during forward pass
self._hidden_states_buffer = [None] * config.num_hidden_layers
def forward(
self,
@@ -365,13 +367,15 @@
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
hidden_states, _ = encoder_layer(
hidden_states,
attention_mask,
)
self._hidden_states_buffer[idx] = hidden_states
return hidden_states
return self._hidden_states_buffer
class SiglipVisionTransformer(nn.Module):
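The buffer-returning encoder above boils down to a small pattern: entry i of the returned list is the output of layer i, and the same list object is reused on every call, so callers that need hidden states beyond the next forward pass would have to copy them. A toy sketch of that pattern under those assumptions (module names and sizes are illustrative):

```python
import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        # Reused across forward passes; entry i holds the output of layer i.
        self._hidden_states_buffer = [None] * num_layers

    def forward(self, hidden_states: torch.Tensor):
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)
            self._hidden_states_buffer[idx] = hidden_states
        return self._hidden_states_buffer

outputs = TinyEncoder(num_layers=3, dim=8)(torch.randn(2, 8))
assert len(outputs) == 3 and outputs[-1].shape == (2, 8)
```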
@@ -393,18 +397,22 @@ class SiglipVisionTransformer(nn.Module):
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# make sure the pixel values are the correct dtype
pixel_values = pixel_values.to(
dtype=self.embeddings.patch_embedding.weight.dtype
)
hidden_states = self.embeddings(pixel_values)
# NOTE: up until this point, the outputs are exactly the same as in the
# transformers code. The values evaluate slightly differently in our
# encoder layer.
encoder_outputs = self.encoder(
all_encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
)
last_hidden_state = encoder_outputs
last_hidden_state = all_encoder_outputs[-1]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
# pooler_output=pooled_output,
# hidden_states=encoder_outputs,
hidden_states=all_encoder_outputs,
)
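One consequence worth noting, inferred from the encoder loop above rather than stated in the diff: the returned hidden_states list has one entry per encoder layer and no embedding entry at index 0, unlike output_hidden_states=True in transformers. A small sketch of reading this container, assuming BaseModelOutputWithPooling from transformers:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

# Pretend per-layer outputs from a 3-layer encoder (shapes are illustrative).
all_encoder_outputs = [torch.randn(1, 10, 8) for _ in range(3)]

out = BaseModelOutputWithPooling(
    last_hidden_state=all_encoder_outputs[-1],
    hidden_states=all_encoder_outputs,
)
# out.hidden_states[i] is the output of layer i; there is no embedding entry.
feature = out.hidden_states[1]  # hypothetical feature-layer index
print(feature.shape)  # torch.Size([1, 10, 8])
```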

View File

@@ -1,5 +1,5 @@
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
if config.model_type == "llama" or config.model_type == "granite":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
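The model_type check above routes both Llama and Granite text backbones to the same flash Llama implementation. A trimmed-down sketch of that dispatch; the function below is a stand-in for illustration, not the server's actual loader:

```python
def pick_text_model_class(model_type: str) -> str:
    # Granite decoders are loaded with the flash Llama implementation here,
    # mirroring the check in the diff above.
    if model_type in ("llama", "granite"):
        return "FlashLlamaForCausalLM"
    raise ValueError(f"Unsupported text model_type: {model_type}")

print(pick_text_model_class("granite"))  # FlashLlamaForCausalLM
```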

View File

@@ -247,8 +247,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
assert image_size % patch_size == 0
npatches = image_size // patch_size
# Dimensions are intentionally swapped to be bug-compatible with