fix: refactor/simplify conditionals

This commit is contained in:
drbh 2025-02-18 23:36:02 +00:00
parent e4e6ea2598
commit 05333b7cbe

View File

@ -236,10 +236,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if (
config.model_type == "qwen2_vl"
or config.model_type == "qwen2_5_vl"
):
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
@ -430,10 +427,7 @@ class VlmCausalLM(FlashCausalLM):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if (
self.model.config.model_type == "qwen2_vl"
or self.model.config.model_type == "qwen2_5_vl"
):
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw