feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl

David Holtz 2024-10-30 18:46:05 +00:00
parent befd9f6735
commit 8648212c76
4 changed files with 15 additions and 11 deletions
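Background: Qwen2-VL uses multimodal rotary position embeddings (mrope), so each token carries three position ids (temporal, height, width) rather than one. CUDA graphs replay against tensors captured with fixed shapes, so the batching code has to carry this multidimensional layout end to end. A minimal sketch of the shape convention assumed throughout this diff (illustration only, not repository code):

import torch

# Text-only models use 1D position ids of shape [seq_len], while Qwen2-VL's
# mrope carries one id per (temporal, height, width) axis: shape [3, seq_len].
seq_len = 4
text_ids = torch.arange(seq_len)                # shape [4]
mrope_ids = text_ids.unsqueeze(0).repeat(3, 1)  # shape [3, 4]
assert mrope_ids.shape == (3, seq_len)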

integration-tests/models/test_flash_qwen2_vl.py

@@ -3,7 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle
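Passing cuda_graphs=[0] pinned graph capture off for this test; dropping it restores the launcher's default capture sizes. A sketch of how such an override can disable capture, assuming a hypothetical helper rather than TGI's actual launcher logic:

# Hypothetical: batch size 0 filters out to an empty list, so no graphs
# are captured; None falls back to the default capture sizes.
def effective_cuda_graph_sizes(override, defaults=(1, 2, 4, 8, 16, 32)):
    sizes = defaults if override is None else override
    return [bs for bs in sizes if bs > 0]

assert effective_cuda_graph_sizes([0]) == []
assert effective_cuda_graph_sizes(None) == [1, 2, 4, 8, 16, 32]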

server/text_generation_server/models/custom_modeling/qwen2_vl.py

@@ -377,9 +377,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
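The added guard lets get_position_ids accept both the flattened 1D input_ids used by the flash-attention batching path and a 2D [batch, seq_len] tensor, while the new image_grid_thw=None default permits text-only calls such as graph warmup. The normalization in isolation, as a runnable sketch:

import torch

def normalize_batch_dim(batch_input_ids: torch.Tensor) -> torch.Tensor:
    # Promote 1D [seq_len] ids to [1, seq_len] so the batch dimension
    # (shape[0]) is always present for the mrope computation that follows.
    if batch_input_ids.dim() == 1:
        batch_input_ids = batch_input_ids.unsqueeze(0)
    return batch_input_ids

assert normalize_batch_dim(torch.arange(5)).shape == (1, 5)
assert normalize_batch_dim(torch.zeros(2, 5, dtype=torch.long)).shape == (2, 5)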

server/text_generation_server/models/flash_causal_lm.py

@@ -1430,6 +1430,10 @@ class FlashCausalLM(Model):
         else:
             state = None
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
@@ -1806,7 +1810,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
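Since position_ids may now be multidimensional, the padding slice is bounded by the last dimension (the token axis) instead of the first. The two bounds coincide for 1D ids, but shape[0] of an mrope tensor would be 3:

import torch

ids_1d = torch.arange(5)        # shape [5]
ids_mrope = torch.zeros(3, 5)   # shape [3, 5]
assert ids_1d.shape[-1] == ids_1d.shape[0] == 5
assert ids_mrope.shape[-1] == 5  # shape[0] is 3, the mrope axis count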
@@ -1981,7 +1985,7 @@ class FlashCausalLM(Model):
             # instantly become of shape [BATCH_SIZE]
             if prefill and finished_prefilling:
                 indices = batch.cu_seqlen_prefill[1:] - 1
-                batch.position_ids = batch.position_ids[indices]
+                batch.position_ids = batch.position_ids[(..., indices)]
                 batch.slot_indices = batch.slot_indices[indices]
                 batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                     indices
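Indexing with (..., indices) selects along the last axis whatever the rank, so the same line now handles 1D text ids and multidimensional mrope ids:

import torch

idx = torch.tensor([2, 4])
t1 = torch.arange(6)                  # shape [6]
t3 = torch.arange(18).reshape(3, 6)   # shape [3, 6]
assert t1[(..., idx)].shape == (2,)
assert t3[(..., idx)].shape == (3, 2)
assert torch.equal(t1[(..., idx)], t1[idx])  # unchanged for the 1D case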

server/text_generation_server/models/vlm_causal_lm.py

@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
                 position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
+                    input_ids, batch.image_grid_thw
                 )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+                batch.position_ids = position_ids
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
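The prefill path now keys off model_type and off whether the batch still holds 1D ids, computes the mrope ids once, and keeps the full multidimensional tensor on the batch rather than collapsing it to position_ids[0, 0, :], which previously discarded the extra axes and prevented graph replay. A condensed sketch of the new dispatch, using hypothetical stand-ins for the model and arguments:

import torch

def resolve_position_ids(model, position_ids, input_ids, image_grid_thw):
    # Only compute mrope ids when the model needs them and the batch has
    # not been upgraded from 1D text-style ids yet; otherwise pass through.
    if model.config.model_type == "qwen2_vl" and position_ids.dim() == 1:
        position_ids = model.get_position_ids(input_ids, image_grid_thw)
    return position_ids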