From 8648212c76b01af1d004ba477630703b09c1c2a1 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 30 Oct 2024 18:46:05 +0000
Subject: [PATCH] feat: support multidimensional position ids on batch to
 enable cuda graphs on qwen2-vl

---
 integration-tests/models/test_flash_qwen2_vl.py          |  2 +-
 .../models/custom_modeling/qwen2_vl.py                   |  5 ++++-
 .../text_generation_server/models/flash_causal_lm.py     |  8 ++++++--
 server/text_generation_server/models/vlm_causal_lm.py    | 11 ++++-------
 4 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py
index 357de2b1..8ece4fed 100644
--- a/integration-tests/models/test_flash_qwen2_vl.py
+++ b/integration-tests/models/test_flash_qwen2_vl.py
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle
 
 
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 6ebc3d4e..63dcff72 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -377,9 +377,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
+
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8ab1a811..4fec65a7 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1430,6 +1430,10 @@ class FlashCausalLM(Model):
         else:
             state = None
 
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
+
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
@@ -1806,7 +1810,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
@@ -1981,7 +1985,7 @@ class FlashCausalLM(Model):
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
             indices = batch.cu_seqlen_prefill[1:] - 1
-            batch.position_ids = batch.position_ids[indices]
+            batch.position_ids = batch.position_ids[(..., indices)]
             batch.slot_indices = batch.slot_indices[indices]
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 9a3db502..52b06f8d 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1:
                 position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
+                    input_ids, batch.image_grid_thw
                 )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+            batch.position_ids = position_ids
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
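A minimal standalone sketch of why the two checks in the hunks above generalize: indexing with (..., indices) selects along the last dimension whether position ids are the flat decode layout or a stacked (3, seq_len) layout, and dim() == 1 signals that the ids still need to be expanded. The shapes, values, and row semantics below are illustrative assumptions, not taken from the patch.

# Illustrative sketch, assuming flat (seq_len,) vs stacked (3, seq_len) position ids.
import torch

indices = torch.tensor([2, 5])           # e.g. last-token indices, cu_seqlen_prefill[1:] - 1

flat = torch.arange(6)                   # decode-style ids, shape (6,)
mrope = torch.arange(18).reshape(3, 6)   # stacked ids, shape (3, 6): one row per
                                         # temporal/height/width component (assumed layout)

print(flat[(..., indices)].shape)        # torch.Size([2])
print(mrope[(..., indices)].shape)       # torch.Size([3, 2])

# The dim() check used in the patch: only build stacked ids when they are still flat.
for ids in (flat, mrope):
    if ids.dim() == 1:
        print("flat -> would call get_position_ids to build (3, seq_len) ids")
    else:
        print("already multidimensional -> reuse as-is")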