mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
fix qwen2 vl crash in continous batching (#3004)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ed96ba6503
commit
06dfe9abfe
@ -750,6 +750,15 @@ class FlashCausalLMBatch(Batch):
|
||||
adapter_segment_builder = None
|
||||
else:
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
if (
|
||||
batches[0].position_ids is not None
|
||||
and batches[0].position_ids.dim() == 2
|
||||
):
|
||||
# Qwen2_vl case:
|
||||
position_ids = batches[0].position_ids.new_empty(
|
||||
(total_batch_size, batches[0].position_ids.shape[-1])
|
||||
)
|
||||
else:
|
||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
|
||||
@ -2133,6 +2142,10 @@ class FlashCausalLM(Model):
|
||||
if not prefill or (prefill and finished_prefilling):
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
if batch.position_ids.dim() == 2:
|
||||
# Qwen2_vl case:
|
||||
batch.position_ids += accepted_ids.unsqueeze(-1)
|
||||
else:
|
||||
batch.position_ids += accepted_ids
|
||||
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
|
||||
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user