fix qwen2 vl crash in continuous batching (#3004)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Authored by Wang, Yi on 2025-02-21 07:36:45 +08:00; committed by GitHub
parent ed96ba6503
commit 06dfe9abfe


@@ -750,6 +750,15 @@ class FlashCausalLMBatch(Batch):
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            if (
                batches[0].position_ids is not None
                and batches[0].position_ids.dim() == 2
            ):
                # Qwen2_vl case:
                position_ids = batches[0].position_ids.new_empty(
                    (total_batch_size, batches[0].position_ids.shape[-1])
                )
            else:
                position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
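
Context for the first hunk: when batches are merged for continuous batching, the pre-allocated position_ids buffer has to match the layout the model uses. Text-only models keep a single scalar position per sequence (1-D tensor), while Qwen2-VL-style multimodal rotary embeddings keep several coordinates per token, so position_ids is 2-D and the trailing dimension must be preserved. The following is a minimal standalone sketch of that allocation logic, not TGI code; the merge_position_ids helper and the shapes are illustrative only.

import torch

def merge_position_ids(position_ids_per_batch, total_batch_size):
    # Allocate a merged buffer with the same layout as the per-batch tensors.
    first = position_ids_per_batch[0]
    if first.dim() == 2:
        # 2-D case (Qwen2-VL-style): keep the trailing coordinate dimension.
        merged = first.new_empty((total_batch_size, first.shape[-1]))
    else:
        # 1-D case: one scalar position per sequence.
        merged = first.new_empty(total_batch_size)
    offset = 0
    for p in position_ids_per_batch:
        merged[offset : offset + p.shape[0]] = p
        offset += p.shape[0]
    return merged

# Text-only batches: 1-D positions -> merged shape (3,)
print(merge_position_ids([torch.tensor([5, 9]), torch.tensor([3])], 3).shape)
# Qwen2-VL-style batches: one coordinate row per sequence -> merged shape (3, 3)
print(merge_position_ids([torch.zeros(2, 3, dtype=torch.long),
                          torch.zeros(1, 3, dtype=torch.long)], 3).shape)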
@@ -2133,6 +2142,10 @@ class FlashCausalLM(Model):
        if not prefill or (prefill and finished_prefilling):
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            if batch.position_ids.dim() == 2:
                # Qwen2_vl case:
                batch.position_ids += accepted_ids.unsqueeze(-1)
            else:
                batch.position_ids += accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
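
Context for the second hunk: after a decode step each sequence's position is advanced by its number of accepted tokens. accepted_ids is 1-D with one entry per sequence; adding it directly to a 2-D position_ids tensor fails to broadcast, which is the crash this commit fixes, so a trailing dimension is added first. A minimal standalone sketch of the failing and fixed broadcast, assuming 4 sequences and 3 rotary coordinates per token (illustrative shapes, not TGI code):

import torch

position_ids = torch.zeros(4, 3, dtype=torch.long)  # (batch, n_coords), Qwen2-VL-style
accepted_ids = torch.ones(4, dtype=torch.long)       # (batch,), one count per sequence

try:
    position_ids + accepted_ids                       # (4, 3) + (4,): shape mismatch
except RuntimeError as e:
    print("crash:", e)

# Fix: add a trailing dim so the advance broadcasts across all coordinates.
position_ids += accepted_ids.unsqueeze(-1)            # (4, 3) + (4, 1) -> (4, 3)
print(position_ids)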