mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-08 18:32:06 +00:00

feat: improve llava next pooling for granite vision
parent 1ff9d185d5
commit 30bdf922bd
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava-NeXT model."""
+"""PyTorch Llava-NeXT model."""

 from typing import List, Optional, Tuple

@@ -115,12 +115,27 @@ class LlavaNextForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         vision_config = config.vision_config

+        vision_feature_layer = []
+        # If vision_feature_layer is an int, we assume it is a single layer index.
+        if isinstance(config.vision_feature_layer, int):
             # Instead of selecting hidden_states[-2], compute only the first
             # (n - 2 + 1) layers and don't pool.
             if config.vision_feature_layer < 0:
-                vision_config.num_hidden_layers += config.vision_feature_layer + 1
+                # vision_config.num_hidden_layers += config.vision_feature_layer + 1
+                num = vision_config.num_hidden_layers + config.vision_feature_layer + 1
+                vision_feature_layer = [num]
             else:
-                vision_config.num_hidden_layers = config.vision_feature_layer + 1
+                # vision_config.num_hidden_layers = config.vision_feature_layer + 1
+                vision_feature_layer = [config.vision_feature_layer + 1]
+        elif isinstance(config.vision_feature_layer, list):
+            # If vision_feature_layer is a list, we assume it is a list of layer
+            # indices and we select the hidden states at those layers.
+            vision_feature_layer = config.vision_feature_layer
+
+        self.vision_feature_layer = vision_feature_layer
+
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
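
The constructor change above stops truncating the vision tower and instead records a list of layer indices to pool over. A minimal standalone sketch of that normalization, assuming a plain int-or-list config value (the helper name and the example values are illustrative, not part of the repository):

# Sketch of the index normalization above; names and values are illustrative.
def normalize_vision_feature_layer(vision_feature_layer, num_hidden_layers):
    if isinstance(vision_feature_layer, int):
        if vision_feature_layer < 0:
            # Negative values count back from the end, mirroring the diff:
            # e.g. -2 with a 24-layer tower selects buffer index 24 - 2 + 1 = 23.
            return [num_hidden_layers + vision_feature_layer + 1]
        return [vision_feature_layer + 1]
    # A list is used as-is: one buffer index per layer to pool.
    return list(vision_feature_layer)

assert normalize_vision_feature_layer(-2, 24) == [23]
assert normalize_vision_feature_layer([3, 7, 15, 26], 27) == [3, 7, 15, 26]
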
@@ -194,6 +209,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
                     f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
                 )

+        # vision_feature_layer is a list of layer indices, we select the hidden states at those layers
+        hs_pool = [
+            image_features.hidden_states[layer_idx]
+            for layer_idx in self.vision_feature_layer
+        ]
+        selected_image_feature = torch.cat(hs_pool, dim=-1)
+
         image_features = self.multi_modal_projector(selected_image_feature)

         # split up image_features for each of the individual images
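
The pooling itself is a concatenation along the feature dimension: with k selected layers of hidden size d, the projector input width becomes k * d, which the checkpoint's projector weights must already account for. A toy illustration with made-up shapes (729 patches, hidden size 1152, four pooled layers):

import torch

# Fake per-layer encoder outputs: 28 tensors of shape (num_patches, hidden_size).
hidden_states = [torch.randn(729, 1152) for _ in range(28)]
vision_feature_layer = [3, 7, 15, 26]  # illustrative multi-layer pooling indices

hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
print(selected_image_feature.shape)  # torch.Size([729, 4608]) == (729, 4 * 1152)
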
@@ -358,6 +358,8 @@ class SiglipEncoder(nn.Module):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        # Pre-allocate reusable list to avoid memory allocation during forward pass
+        self._hidden_states_buffer = [None] * config.num_hidden_layers

     def forward(
         self,
@@ -365,13 +367,15 @@ class SiglipEncoder(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
     ):
         hidden_states = inputs_embeds

         for idx, encoder_layer in enumerate(self.layers):
             hidden_states, _ = encoder_layer(
                 hidden_states,
                 attention_mask,
             )
+            self._hidden_states_buffer[idx] = hidden_states

-        return hidden_states
+        return self._hidden_states_buffer


 class SiglipVisionTransformer(nn.Module):
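
Taken together, the two encoder hunks make `SiglipEncoder.forward` return every layer's output instead of only the last one. A self-contained sketch of the same pattern with a stand-in layer (not the repo's `SiglipEncoderLayer`):

import torch
from torch import nn

class TinyEncoder(nn.Module):
    # Stand-in encoder demonstrating the buffer pattern from the diff.
    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        # Allocated once; forward() only overwrites the slots.
        self._hidden_states_buffer = [None] * num_layers

    def forward(self, inputs_embeds):
        hidden_states = inputs_embeds
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)
            self._hidden_states_buffer[idx] = hidden_states
        return self._hidden_states_buffer

encoder = TinyEncoder(num_layers=4, dim=8)
all_outputs = encoder(torch.randn(2, 8))
last_hidden_state = all_outputs[-1]  # what SiglipVisionTransformer consumes
assert len(all_outputs) == 4

Note the buffer only stores references: it saves the small per-call list allocation, not activation memory, and a second forward pass overwrites the slots the previous return value pointed at, so callers that keep results across calls must copy them.
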
@@ -393,18 +397,22 @@ class SiglipVisionTransformer(nn.Module):
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

+        # make sure the pixel values are the correct dtype
+        pixel_values = pixel_values.to(
+            dtype=self.embeddings.patch_embedding.weight.dtype
+        )
         hidden_states = self.embeddings(pixel_values)

         # NOTE: up until this point, the logits are exactly
         # the same as the transformers code. The values evaluate
         # slightly differently in our encoder layer.
-        encoder_outputs = self.encoder(
+        all_encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
         )
-        last_hidden_state = encoder_outputs
+        last_hidden_state = all_encoder_outputs[-1]

         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
             # pooler_output=pooled_output,
-            # hidden_states=encoder_outputs,
+            hidden_states=all_encoder_outputs,
         )
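
The new cast guards against `pixel_values` arriving in a different precision than the tower weights, e.g. float32 inputs into a bfloat16 patch embedding. A small illustration of the idiom, with made-up SigLIP-like dimensions:

import torch
from torch import nn

# Illustrative patch embedding in bfloat16, as a quantized/cast tower might hold.
patch_embedding = nn.Conv2d(3, 1152, kernel_size=14, stride=14).to(torch.bfloat16)
pixel_values = torch.randn(1, 3, 384, 384)  # arrives as float32
pixel_values = pixel_values.to(dtype=patch_embedding.weight.dtype)
hidden_states = patch_embedding(pixel_values)  # no dtype-mismatch RuntimeError
print(hidden_states.shape, hidden_states.dtype)  # (1, 1152, 27, 27) torch.bfloat16
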
@@ -1,5 +1,5 @@
 def load_text_model(prefix, config, weights, name=None):
-    if config.model_type == "llama":
+    if config.model_type == "llama" or config.model_type == "granite":
         from text_generation_server.models.custom_modeling.flash_llama_modeling import (
             FlashLlamaForCausalLM,
         )
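
Routing `granite` through the Llama path works because Granite's decoder blocks are Llama-compatible, differing mainly in config-level scaling parameters. A reduced sketch of the dispatch pattern (the constructor call and the fallback are sketched here, not copied from the repo):

def load_text_model(prefix, config, weights, name=None):
    # Granite reuses the flash Llama modeling code.
    if config.model_type in ("llama", "granite"):
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )
        return FlashLlamaForCausalLM(prefix, config, weights)  # call sketched
    raise NotImplementedError(f"Unsupported model type {config.model_type}")
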
@@ -247,8 +247,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     image_size = config.vision_config.image_size
     patch_size = config.vision_config.patch_size

-    assert image_size % patch_size == 0
-
     npatches = image_size // patch_size

     # Dimensions are intentionally swapped to be bug-compatible with
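
Dropping the assert matters because the floor division already handles towers whose patch size does not evenly divide the image size, which would otherwise crash at load time even though the patch grid is well defined. For instance, assuming SigLIP-so400m-style values (illustrative here):

image_size, patch_size = 384, 14  # SigLIP-so400m-style values (illustrative)
# 384 % 14 == 6, so the removed `assert image_size % patch_size == 0` would
# have failed, even though Conv2d patching yields a clean 27x27 grid:
npatches = image_size // patch_size
print(npatches, npatches * npatches)  # 27 729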