From 2204f91f32a202c35b91c1c108181fdbb2266098 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 6 Jun 2025 14:54:10 +0000 Subject: [PATCH] fix: adjust llava logic and bump snaps --- .../test_flash_llava_next_load.json | 80 +++++++++---------- .../test_flash_llava_next_simple.json | 20 ++--- .../models/custom_modeling/llava_next.py | 15 ++-- 3 files changed, 59 insertions(+), 56 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json index 0390374d..e714822b 100644 --- a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json @@ -9,61 +9,61 @@ "tokens": [ { "id": 13, - "logprob": -0.007621765, + "logprob": -0.052612305, "special": false, "text": "\n" }, { "id": 13, - "logprob": -0.20812988, + "logprob": -0.079589844, "special": false, "text": "\n" }, { "id": 16114, - "logprob": -1.2587891, + "logprob": -1.6865234, "special": false, "text": "Once" }, { "id": 3714, - "logprob": -0.20825195, + "logprob": -0.20983887, "special": false, "text": " upon" }, { "id": 264, - "logprob": -0.0017709732, + "logprob": -0.0014019012, "special": false, "text": " a" }, { "id": 727, - "logprob": -0.011932373, + "logprob": -0.0121154785, "special": false, "text": " time" }, { "id": 28725, - "logprob": -0.17297363, + "logprob": -0.15405273, "special": false, "text": "," }, { "id": 736, - "logprob": -0.9057617, + "logprob": -0.4802246, "special": false, "text": " there" }, { "id": 403, - "logprob": -0.05758667, + "logprob": -0.03289795, "special": false, "text": " was" }, { "id": 264, - "logprob": -0.00970459, + "logprob": -0.01423645, "special": false, "text": " a" } @@ -82,61 +82,61 @@ "tokens": [ { "id": 13, - "logprob": -0.007621765, + "logprob": -0.052612305, "special": false, "text": "\n" }, { "id": 13, - "logprob": -0.20275879, + "logprob": -0.07946777, "special": false, "text": "\n" }, { "id": 16114, - "logprob": -1.2578125, + "logprob": -1.6914062, "special": false, "text": "Once" }, { "id": 3714, - "logprob": -0.2084961, + "logprob": -0.21020508, "special": false, "text": " upon" }, { "id": 264, - "logprob": -0.0017738342, + "logprob": -0.0014238358, "special": false, "text": " a" }, { "id": 727, - "logprob": -0.011932373, + "logprob": -0.012138367, "special": false, "text": " time" }, { "id": 28725, - "logprob": -0.17297363, + "logprob": -0.15625, "special": false, "text": "," }, { "id": 736, - "logprob": -0.9057617, + "logprob": -0.47827148, "special": false, "text": " there" }, { "id": 403, - "logprob": -0.05758667, + "logprob": -0.03289795, "special": false, "text": " was" }, { "id": 264, - "logprob": -0.00970459, + "logprob": -0.01423645, "special": false, "text": " a" } @@ -155,61 +155,61 @@ "tokens": [ { "id": 13, - "logprob": -0.007621765, + "logprob": -0.052246094, "special": false, "text": "\n" }, { "id": 13, - "logprob": -0.20275879, + "logprob": -0.07739258, "special": false, "text": "\n" }, { "id": 16114, - "logprob": -1.2578125, + "logprob": -1.6875, "special": false, "text": "Once" }, { "id": 3714, - "logprob": -0.2084961, + "logprob": -0.20922852, "special": false, "text": " upon" }, { "id": 264, - "logprob": -0.0017738342, + "logprob": -0.0014228821, "special": false, "text": " a" }, { "id": 727, - "logprob": -0.011932373, + "logprob": -0.012130737, "special": false, "text": " time" }, { "id": 28725, - "logprob": -0.17297363, + "logprob": -0.15612793, "special": false, "text": "," }, { "id": 736, - "logprob": -0.9057617, + "logprob": -0.47827148, "special": false, "text": " there" }, { "id": 403, - "logprob": -0.05758667, + "logprob": -0.032928467, "special": false, "text": " was" }, { "id": 264, - "logprob": -0.00970459, + "logprob": -0.014144897, "special": false, "text": " a" } @@ -228,61 +228,61 @@ "tokens": [ { "id": 13, - "logprob": -0.007621765, + "logprob": -0.052978516, "special": false, "text": "\n" }, { "id": 13, - "logprob": -0.20812988, + "logprob": -0.080444336, "special": false, "text": "\n" }, { "id": 16114, - "logprob": -1.2587891, + "logprob": -1.6826172, "special": false, "text": "Once" }, { "id": 3714, - "logprob": -0.20825195, + "logprob": -0.21044922, "special": false, "text": " upon" }, { "id": 264, - "logprob": -0.0017709732, + "logprob": -0.0014238358, "special": false, "text": " a" }, { "id": 727, - "logprob": -0.011932373, + "logprob": -0.012107849, "special": false, "text": " time" }, { "id": 28725, - "logprob": -0.17297363, + "logprob": -0.15405273, "special": false, "text": "," }, { "id": 736, - "logprob": -0.9057617, + "logprob": -0.47875977, "special": false, "text": " there" }, { "id": 403, - "logprob": -0.05758667, + "logprob": -0.03289795, "special": false, "text": " was" }, { "id": 264, - "logprob": -0.00970459, + "logprob": -0.01423645, "special": false, "text": " a" } diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json index f0f2ee9e..d0b311ca 100644 --- a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json @@ -8,61 +8,61 @@ "tokens": [ { "id": 13, - "logprob": -0.00756073, + "logprob": -0.052612305, "special": false, "text": "\n" }, { "id": 13, - "logprob": -0.20117188, + "logprob": -0.07739258, "special": false, "text": "\n" }, { "id": 16114, - "logprob": -1.2597656, + "logprob": -1.6914062, "special": false, "text": "Once" }, { "id": 3714, - "logprob": -0.20825195, + "logprob": -0.21020508, "special": false, "text": " upon" }, { "id": 264, - "logprob": -0.00178051, + "logprob": -0.0014228821, "special": false, "text": " a" }, { "id": 727, - "logprob": -0.011955261, + "logprob": -0.012123108, "special": false, "text": " time" }, { "id": 28725, - "logprob": -0.17541504, + "logprob": -0.15625, "special": false, "text": "," }, { "id": 736, - "logprob": -0.91308594, + "logprob": -0.47875977, "special": false, "text": " there" }, { "id": 403, - "logprob": -0.058410645, + "logprob": -0.033416748, "special": false, "text": " was" }, { "id": 264, - "logprob": -0.009689331, + "logprob": -0.014137268, "special": false, "text": " a" } diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 1b1f9de5..240555ae 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -133,6 +133,8 @@ class LlavaNextForConditionalGeneration(nn.Module): # and we select the hidden states at those layers vision_feature_layer = config.vision_feature_layer + else: + vision_feature_layer = [vision_config.num_hidden_layers - 1] self.vision_feature_layer = vision_feature_layer @@ -209,12 +211,13 @@ class LlavaNextForConditionalGeneration(nn.Module): f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." ) - # vision_feature_layer is a list of layer indices, we select the hidden states at those layers - hs_pool = [ - image_features.hidden_states[layer_idx] - for layer_idx in self.vision_feature_layer - ] - selected_image_feature = torch.cat(hs_pool, dim=-1) + if image_features.hidden_states is not None: + # vision_feature_layer is a list of layer indices, we select the hidden states at those layers + hs_pool = [ + image_features.hidden_states[layer_idx] + for layer_idx in self.vision_feature_layer + ] + selected_image_feature = torch.cat(hs_pool, dim=-1) image_features = self.multi_modal_projector(selected_image_feature)