feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl

David Holtz 2024-10-30 18:46:05 +00:00
parent befd9f6735
commit 8648212c76
4 changed files with 15 additions and 11 deletions

View File

@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle

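With the batch paths able to carry multidimensional position ids, the integration test no longer needs to opt out of graph capture: in the TGI launcher, cuda_graphs is a list of batch sizes to capture and [0] disables capture entirely, so dropping the kwarg falls back to the server's default capture sizes. A hedged sketch of a variant pinning explicit sizes (hypothetical fixture name; kwarg semantics assumed from the launcher's --cuda-graphs option):

    import pytest

    @pytest.fixture(scope="module")
    def flash_qwen2_vl_small_graphs_handle(launcher):
        # Hypothetical: capture graphs only for batch sizes 1 and 2,
        # assuming cuda_graphs maps to the launcher's --cuda-graphs list.
        with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[1, 2]) as handle:
            yield handle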
View File

@@ -377,9 +377,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],

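The new dim() guard lets the same helper serve prefill (batched 2-D input) and the CUDA-graph warmup path, which hands in decode-shaped 1-D ids. A shape-only sketch of that normalization (illustrative, not the real helper, which goes on to compute the text/image mrope sections; the ones-fill mirrors only the allocation shown above):

    import torch

    def build_mrope_ids(batch_input_ids: torch.Tensor) -> torch.Tensor:
        # Warmup/decode passes a flat [seq] tensor; normalize to [batch, seq]
        # so the three mrope channels (temporal, height, width) can be
        # allocated uniformly, as in the hunk above.
        if batch_input_ids.dim() == 1:
            batch_input_ids = batch_input_ids.unsqueeze(0)
        return torch.ones(
            3, batch_input_ids.shape[0], batch_input_ids.shape[1], dtype=torch.long
        )

    assert build_mrope_ids(torch.zeros(7, dtype=torch.long)).shape == (3, 1, 7)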
View File

@@ -1430,6 +1430,10 @@ class FlashCausalLM(Model):
         else:
             state = None

+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
+
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
@@ -1806,7 +1810,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
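For 1-D position ids, shape[0] and shape[-1] coincide, so this switch is a no-op for existing models; it only changes meaning once the ids grow leading mrope dimensions, where the token count lives in the last axis. A quick check:

    import torch

    pos = torch.arange(5)                 # 1-D ids: [tokens]
    assert pos.shape[0] == pos.shape[-1]  # the rewrite is a no-op here
    mrope = torch.zeros(3, 5)             # multidimensional ids: [channels, tokens]
    assert mrope.shape[-1] == 5           # token count now sits in the last axis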
@@ -1981,7 +1985,7 @@ class FlashCausalLM(Model):
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
             indices = batch.cu_seqlen_prefill[1:] - 1
-            batch.position_ids = batch.position_ids[indices]
+            batch.position_ids = batch.position_ids[(..., indices)]
             batch.slot_indices = batch.slot_indices[indices]
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
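Ellipsis indexing makes this gather rank-agnostic: t[(..., idx)] always selects along the last dimension, so the same line trims both the classic 1-D [tokens] ids and qwen2_vl's [channels, tokens] ids to one position per sequence. For example:

    import torch

    indices = torch.tensor([2, 5])          # last token of each prefill sequence
    flat = torch.arange(8)                  # [tokens]
    multi = torch.arange(24).reshape(3, 8)  # [channels, tokens]
    assert flat[(..., indices)].tolist() == [2, 5]
    assert multi[(..., indices)].shape == (3, 2)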

View File

@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
-                position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
-                )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(
+                    input_ids, batch.image_grid_thw
+                )
+                batch.position_ids = position_ids

         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
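Dispatching on model_type plus a dim() == 1 check replaces the old hasattr probe and, crucially, stores the full multidimensional ids on the batch instead of collapsing them back to a single channel, so later steps (including captured graphs) keep all three mrope channels. A hedged sketch of the control flow (the expand is a stand-in for the real get_position_ids call):

    import torch

    def prepare_position_ids(position_ids: torch.Tensor, model_type: str) -> torch.Tensor:
        # Expand exactly once: after the first pass the ids are already
        # multidimensional and the guard is skipped.
        if model_type == "qwen2_vl" and position_ids.dim() == 1:
            # Stand-in for self.model.get_position_ids(input_ids, image_grid_thw).
            position_ids = position_ids.unsqueeze(0).expand(3, -1).contiguous()
        return position_ids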