feat: refactors and calc num features

This commit is contained in:
David Holtz 2024-10-28 16:57:35 +00:00
parent 831a07f990
commit fb1ae6d24c
4 changed files with 9 additions and 14 deletions

View File

@ -3,13 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
with launcher(
"Qwen/Qwen2-VL-7B-Instruct",
max_batch_prefill_tokens=2000,
max_input_length=2000,
max_total_tokens=2001,
cuda_graphs=[0],
) as handle:
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
yield handle

View File

@ -160,9 +160,11 @@ pub struct Qwen2Vl {
}
impl Qwen2Vl {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
// TODO: calculate number of features
6000 / 4
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let num_pixels = height * width;
let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
let start_and_end_tokens = 2;
num_image_tokens + start_and_end_tokens
}
}

View File

@ -490,7 +490,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
position_ids = self.get_position_ids(input_ids, image_grid_thw)
position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
hidden_states = self.text_model(
inputs_embeds=inputs_embeds.squeeze(0),
position_ids=position_ids,

View File

@ -398,9 +398,8 @@ class VlmCausalLM(FlashCausalLM):
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
# TODO: remove the unsqueeze(0)
input_ids=input_ids.unsqueeze(0),
position_ids=position_ids.unsqueeze(0),
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,