From fb1ae6d24ca729218e1dc16e2a313c6f70104154 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 16:57:35 +0000 Subject: [PATCH] feat: refactors and calc num features --- integration-tests/models/test_flash_qwen2_vl.py | 8 +------- router/src/config.rs | 8 +++++--- .../models/custom_modeling/qwen2_vl.py | 2 +- server/text_generation_server/models/vlm_causal_lm.py | 5 ++--- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 86582673..73413eb0 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -3,13 +3,7 @@ import pytest @pytest.fixture(scope="module") def flash_qwen2_vl_handle(launcher): - with launcher( - "Qwen/Qwen2-VL-7B-Instruct", - max_batch_prefill_tokens=2000, - max_input_length=2000, - max_total_tokens=2001, - cuda_graphs=[0], - ) as handle: + with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: yield handle diff --git a/router/src/config.rs b/router/src/config.rs index 7fc27f96..eb16e88b 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -160,9 +160,11 @@ pub struct Qwen2Vl { } impl Qwen2Vl { - pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { - // TODO: calculate number of features - 6000 / 4 + pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { + let num_pixels = height * width; + let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2); + let start_and_end_tokens = 2; + num_image_tokens + start_and_end_tokens } } diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 3bb29b9b..8eee045a 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -490,7 +490,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) inputs_embeds[input_ids == self.image_token_id] = image_embeds - position_ids = self.get_position_ids(input_ids, image_grid_thw) + position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw) hidden_states = self.text_model( inputs_embeds=inputs_embeds.squeeze(0), position_ids=position_ids, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7625c305..a8467059 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -398,9 +398,8 @@ class VlmCausalLM(FlashCausalLM): max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( - # TODO: remove the unsqueeze(0) - input_ids=input_ids.unsqueeze(0), - position_ids=position_ids.unsqueeze(0), + input_ids=input_ids, + position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables,