mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-06 17:32:09 +00:00
feat: improve llava next pooling for granite vision
This commit is contained in:
parent
1ff9d185d5
commit
30bdf922bd
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
"""PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@ -115,12 +115,27 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
|
||||
vision_feature_layer = []
|
||||
# If the vision_feature_layer is an int, we assume it is the number of layers
|
||||
if isinstance(config.vision_feature_layer, int):
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
# vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
num = vision_config.num_hidden_layers + config.vision_feature_layer + 1
|
||||
vision_feature_layer = [num]
|
||||
else:
|
||||
# vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
num_hidden_layers = [config.vision_feature_layer + 1]
|
||||
elif isinstance(config.vision_feature_layer, list):
|
||||
# If the vision_feature_layer is a list, we assume it is a list of layer indices
|
||||
# and we select the hidden states at those layers
|
||||
|
||||
vision_feature_layer = config.vision_feature_layer
|
||||
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
@ -194,6 +209,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
# vision_feature_layer is a list of layer indices, we select the hidden states at those layers
|
||||
hs_pool = [
|
||||
image_features.hidden_states[layer_idx]
|
||||
for layer_idx in self.vision_feature_layer
|
||||
]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
|
@ -358,6 +358,8 @@ class SiglipEncoder(nn.Module):
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
# Pre-allocate reusable list to avoid memory allocation during forward pass
|
||||
self._hidden_states_buffer = [None] * config.num_hidden_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -365,13 +367,15 @@ class SiglipEncoder(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states, _ = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
self._hidden_states_buffer[idx] = hidden_states
|
||||
|
||||
return hidden_states
|
||||
return self._hidden_states_buffer
|
||||
|
||||
|
||||
class SiglipVisionTransformer(nn.Module):
|
||||
@ -393,18 +397,22 @@ class SiglipVisionTransformer(nn.Module):
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# make sure the pixel values are the correct dtype
|
||||
pixel_values = pixel_values.to(
|
||||
dtype=self.embeddings.patch_embedding.weight.dtype
|
||||
)
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
# NOTE: up until this point, the code logits are exactly
|
||||
# the same as the transformers code. The values evaulate
|
||||
# slightly differently in our encoder layer.
|
||||
encoder_outputs = self.encoder(
|
||||
all_encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
last_hidden_state = encoder_outputs
|
||||
last_hidden_state = all_encoder_outputs[-1]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
# pooler_output=pooled_output,
|
||||
# hidden_states=encoder_outputs,
|
||||
hidden_states=all_encoder_outputs,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
def load_text_model(prefix, config, weights, name=None):
|
||||
if config.model_type == "llama":
|
||||
if config.model_type == "llama" or config.model_type == "granite":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
@ -247,8 +247,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
|
||||
assert image_size % patch_size == 0
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
|
Loading…
Reference in New Issue
Block a user