diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 56a9565b..1b1f9de5 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava-NeXT model."""
+"""PyTorch Llava-NeXT model."""
 
 from typing import List, Optional, Tuple
 
@@ -115,12 +115,27 @@ class LlavaNextForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         vision_config = config.vision_config
-        # Instead of selecting in hidden_states[-2].
-        # Instead compute only the n -2 + 1 layers and don't pool
-        if config.vision_feature_layer < 0:
-            vision_config.num_hidden_layers += config.vision_feature_layer + 1
-        else:
-            vision_config.num_hidden_layers = config.vision_feature_layer + 1
+
+        vision_feature_layer = []
+        # If vision_feature_layer is an int, treat it as a single layer index
+        if isinstance(config.vision_feature_layer, int):
+            # Negative values index from the end of the encoder, so convert
+            # them to an absolute index into the per-layer hidden states
+            if config.vision_feature_layer < 0:
+                # vision_config.num_hidden_layers += config.vision_feature_layer + 1
+                num = vision_config.num_hidden_layers + config.vision_feature_layer + 1
+                vision_feature_layer = [num]
+            else:
+                # vision_config.num_hidden_layers = config.vision_feature_layer + 1
+                vision_feature_layer = [config.vision_feature_layer + 1]
+        elif isinstance(config.vision_feature_layer, list):
+            # If vision_feature_layer is a list, we assume it is a list of layer indices
+            # and we select the hidden states at those layers
+
+            vision_feature_layer = config.vision_feature_layer
+
+        self.vision_feature_layer = vision_feature_layer
+
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
@@ -194,6 +209,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
                 f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
             )
 
+        # vision_feature_layer is a list of layer indices; select the hidden states at those layers
+        hs_pool = [
+            image_features.hidden_states[layer_idx]
+            for layer_idx in self.vision_feature_layer
+        ]
+        selected_image_feature = torch.cat(hs_pool, dim=-1)
+
         image_features = self.multi_modal_projector(selected_image_feature)
 
         # split up image_features for each of the individual images
diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py
index 95ac9ede..909820ff 100644
--- a/server/text_generation_server/models/custom_modeling/siglip.py
+++ b/server/text_generation_server/models/custom_modeling/siglip.py
@@ -365,13 +365,15 @@
         attention_mask: Optional[torch.Tensor] = None,
     ):
         hidden_states = inputs_embeds
+        all_hidden_states = []
         for idx, encoder_layer in enumerate(self.layers):
            hidden_states, _ = encoder_layer(
                 hidden_states,
                 attention_mask,
             )
+            all_hidden_states.append(hidden_states)
 
-        return hidden_states
+        return all_hidden_states
 
 
 class SiglipVisionTransformer(nn.Module):
@@ -393,18 +395,22 @@
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
+        # make sure the pixel values are the correct dtype
+        pixel_values = pixel_values.to(
+            dtype=self.embeddings.patch_embedding.weight.dtype
+        )
         hidden_states = self.embeddings(pixel_values)
 
         # NOTE: up until this point, the code logits are exactly
         # the same as the transformers code. The values evaulate
         # slightly differently in our encoder layer.
-        encoder_outputs = self.encoder(
+        all_encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
         )
-        last_hidden_state = encoder_outputs
+        last_hidden_state = all_encoder_outputs[-1]
 
         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
             # pooler_output=pooled_output,
-            # hidden_states=encoder_outputs,
+            hidden_states=all_encoder_outputs,
         )
diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py
index 4447a73f..0e51a7ad 100644
--- a/server/text_generation_server/models/custom_modeling/vlm.py
+++ b/server/text_generation_server/models/custom_modeling/vlm.py
@@ -1,5 +1,5 @@
 def load_text_model(prefix, config, weights, name=None):
-    if config.model_type == "llama":
+    if config.model_type == "llama" or config.model_type == "granite":
         from text_generation_server.models.custom_modeling.flash_llama_modeling import (
             FlashLlamaForCausalLM,
         )
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index b76dbe68..3fe5d852 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -247,8 +247,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     image_size = config.vision_config.image_size
     patch_size = config.vision_config.patch_size
 
-    assert image_size % patch_size == 0
-
     npatches = image_size // patch_size
 
     # Dimensions are intentionally swapped to be bug-compatible with
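
For reference, the layer selection introduced in `llava_next.py` above amounts to indexing the per-layer hidden states returned by the vision tower and concatenating them along the feature dimension. A minimal standalone sketch follows (the helper name `pool_vision_features` and the tensor shapes are illustrative only, not part of the diff):

import torch


def pool_vision_features(hidden_states, vision_feature_layer):
    # hidden_states: per-layer outputs of the vision encoder
    # vision_feature_layer: list of layer indices, as normalized in __init__
    hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    # Concatenate the selected layers along the last (feature) dimension
    return torch.cat(hs_pool, dim=-1)


# Example: four layer outputs of shape (num_patches, hidden_size)
layers = [torch.randn(729, 1152) for _ in range(4)]
pooled = pool_vision_features(layers, vision_feature_layer=[-4, -1])
print(pooled.shape)  # torch.Size([729, 2304])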