mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: refactors and calc num features
This commit is contained in:
parent
831a07f990
commit
fb1ae6d24c
@ -3,13 +3,7 @@ import pytest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_qwen2_vl_handle(launcher):
|
||||
with launcher(
|
||||
"Qwen/Qwen2-VL-7B-Instruct",
|
||||
max_batch_prefill_tokens=2000,
|
||||
max_input_length=2000,
|
||||
max_total_tokens=2001,
|
||||
cuda_graphs=[0],
|
||||
) as handle:
|
||||
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
|
@ -160,9 +160,11 @@ pub struct Qwen2Vl {
|
||||
}
|
||||
|
||||
impl Qwen2Vl {
|
||||
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
||||
// TODO: calculate number of features
|
||||
6000 / 4
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let num_pixels = height * width;
|
||||
let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
|
||||
let start_and_end_tokens = 2;
|
||||
num_image_tokens + start_and_end_tokens
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -490,7 +490,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
||||
|
||||
position_ids = self.get_position_ids(input_ids, image_grid_thw)
|
||||
position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds.squeeze(0),
|
||||
position_ids=position_ids,
|
||||
|
@ -398,9 +398,8 @@ class VlmCausalLM(FlashCausalLM):
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
# TODO: remove the unsqueeze(0)
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
position_ids=position_ids.unsqueeze(0),
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
|
Loading…
Reference in New Issue
Block a user