From eb0194a9c1b143168a86045da9c3dd26d83586f1 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 10 Feb 2025 01:54:45 -0800 Subject: [PATCH] fix qwen2 vl crash in continous batching Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f268e499..e268af8b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -750,7 +750,16 @@ class FlashCausalLMBatch(Batch): adapter_segment_builder = None else: input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) + if ( + batches[0].position_ids is not None + and batches[0].position_ids.dim() == 2 + ): + # Qwen2_vl case: + position_ids = batches[0].position_ids.new_empty( + (total_batch_size, batches[0].position_ids.shape[-1]) + ) + else: + position_ids = batches[0].position_ids.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size @@ -2133,7 +2142,11 @@ class FlashCausalLM(Model): if not prefill or (prefill and finished_prefilling): batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids - batch.position_ids += accepted_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids