fix qwen2 vl crash in continuous batching (#3004)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-02-21 07:36:45 +08:00 committed by GitHub
parent ed96ba6503
commit 06dfe9abfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -750,7 +750,16 @@ class FlashCausalLMBatch(Batch):
adapter_segment_builder = None adapter_segment_builder = None
else: else:
input_ids = batches[0].input_ids.new_empty(total_batch_size) input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size) if (
batches[0].position_ids is not None
and batches[0].position_ids.dim() == 2
):
# Qwen2_vl case:
position_ids = batches[0].position_ids.new_empty(
(total_batch_size, batches[0].position_ids.shape[-1])
)
else:
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size total_batch_size
@ -2133,7 +2142,11 @@ class FlashCausalLM(Model):
if not prefill or (prefill and finished_prefilling): if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
batch.position_ids += accepted_ids if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids batch.slot_indices += accepted_ids