diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 23fdca05..39046f2a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -236,10 +236,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): image = Image.open(BytesIO(chunk.image.data)) # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the # default warmup image is 20x20 - if ( - config.model_type == "qwen2_vl" - or config.model_type == "qwen2_5_vl" - ): + if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: if image.width <= 20: w = image.width * 2 h = image.height * 2 @@ -430,10 +427,7 @@ class VlmCausalLM(FlashCausalLM): max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices - if ( - self.model.config.model_type == "qwen2_vl" - or self.model.config.model_type == "qwen2_5_vl" - ): + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: if position_ids.dim() == 1 and batch.prefilling: position_ids = self.model.get_position_ids( input_ids, batch.image_grid_thw