mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: refactors and calc num features
This commit is contained in:
parent
831a07f990
commit
fb1ae6d24c
@ -3,13 +3,7 @@ import pytest
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_qwen2_vl_handle(launcher):
|
def flash_qwen2_vl_handle(launcher):
|
||||||
with launcher(
|
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
|
||||||
"Qwen/Qwen2-VL-7B-Instruct",
|
|
||||||
max_batch_prefill_tokens=2000,
|
|
||||||
max_input_length=2000,
|
|
||||||
max_total_tokens=2001,
|
|
||||||
cuda_graphs=[0],
|
|
||||||
) as handle:
|
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,9 +160,11 @@ pub struct Qwen2Vl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Qwen2Vl {
|
impl Qwen2Vl {
|
||||||
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||||
// TODO: calculate number of features
|
let num_pixels = height * width;
|
||||||
6000 / 4
|
let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
|
||||||
|
let start_and_end_tokens = 2;
|
||||||
|
num_image_tokens + start_and_end_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -490,7 +490,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
||||||
|
|
||||||
position_ids = self.get_position_ids(input_ids, image_grid_thw)
|
position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
|
||||||
hidden_states = self.text_model(
|
hidden_states = self.text_model(
|
||||||
inputs_embeds=inputs_embeds.squeeze(0),
|
inputs_embeds=inputs_embeds.squeeze(0),
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -398,9 +398,8 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
max_k=batch.max_current_length,
|
max_k=batch.max_current_length,
|
||||||
)
|
)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
# TODO: remove the unsqueeze(0)
|
input_ids=input_ids,
|
||||||
input_ids=input_ids.unsqueeze(0),
|
position_ids=position_ids,
|
||||||
position_ids=position_ids.unsqueeze(0),
|
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
|
Loading…
Reference in New Issue
Block a user