feat: refactors and calc num features

This commit is contained in:
David Holtz 2024-10-28 16:57:35 +00:00
parent 831a07f990
commit fb1ae6d24c
4 changed files with 9 additions and 14 deletions

View File

@ -3,13 +3,7 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher): def flash_qwen2_vl_handle(launcher):
with launcher( with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
"Qwen/Qwen2-VL-7B-Instruct",
max_batch_prefill_tokens=2000,
max_input_length=2000,
max_total_tokens=2001,
cuda_graphs=[0],
) as handle:
yield handle yield handle

View File

@ -160,9 +160,11 @@ pub struct Qwen2Vl {
} }
impl Qwen2Vl { impl Qwen2Vl {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
// TODO: calculate number of features let num_pixels = height * width;
6000 / 4 let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
let start_and_end_tokens = 2;
num_image_tokens + start_and_end_tokens
} }
} }

View File

@ -490,7 +490,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds[input_ids == self.image_token_id] = image_embeds inputs_embeds[input_ids == self.image_token_id] = image_embeds
position_ids = self.get_position_ids(input_ids, image_grid_thw) position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
hidden_states = self.text_model( hidden_states = self.text_model(
inputs_embeds=inputs_embeds.squeeze(0), inputs_embeds=inputs_embeds.squeeze(0),
position_ids=position_ids, position_ids=position_ids,

View File

@ -398,9 +398,8 @@ class VlmCausalLM(FlashCausalLM):
max_k=batch.max_current_length, max_k=batch.max_current_length,
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
# TODO: remove the unsqueeze(0) input_ids=input_ids,
input_ids=input_ids.unsqueeze(0), position_ids=position_ids,
position_ids=position_ids.unsqueeze(0),
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,