mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
fix qwen2 vl crash in continuous batching (#3004)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ed96ba6503
commit
06dfe9abfe
@ -750,7 +750,16 @@ class FlashCausalLMBatch(Batch):
|
|||||||
adapter_segment_builder = None
|
adapter_segment_builder = None
|
||||||
else:
|
else:
|
||||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
if (
|
||||||
|
batches[0].position_ids is not None
|
||||||
|
and batches[0].position_ids.dim() == 2
|
||||||
|
):
|
||||||
|
# Qwen2_vl case:
|
||||||
|
position_ids = batches[0].position_ids.new_empty(
|
||||||
|
(total_batch_size, batches[0].position_ids.shape[-1])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||||
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
|
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
|
||||||
total_batch_size
|
total_batch_size
|
||||||
@ -2133,7 +2142,11 @@ class FlashCausalLM(Model):
|
|||||||
if not prefill or (prefill and finished_prefilling):
|
if not prefill or (prefill and finished_prefilling):
|
||||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||||
batch.speculative_ids = speculative_ids
|
batch.speculative_ids = speculative_ids
|
||||||
batch.position_ids += accepted_ids
|
if batch.position_ids.dim() == 2:
|
||||||
|
# Qwen2_vl case:
|
||||||
|
batch.position_ids += accepted_ids.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
batch.position_ids += accepted_ids
|
||||||
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
|
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
|
||||||
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
||||||
batch.slot_indices += accepted_ids
|
batch.slot_indices += accepted_ids
|
||||||
|
Loading…
Reference in New Issue
Block a user