mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl
This commit is contained in:
parent
befd9f6735
commit
8648212c76
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_qwen2_vl_handle(launcher):
|
||||
with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
|
||||
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
|
@ -377,9 +377,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def get_position_ids(
|
||||
self,
|
||||
batch_input_ids: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.LongTensor],
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
# video_grid_thw is not implemented yet as we do not accept video inputs at the moment
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if batch_input_ids.dim() == 1:
|
||||
batch_input_ids = batch_input_ids.unsqueeze(0)
|
||||
|
||||
position_ids = torch.ones(
|
||||
3,
|
||||
batch_input_ids.shape[0],
|
||||
|
@ -1430,6 +1430,10 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
state = None
|
||||
|
||||
if self.model.config.model_type == "qwen2_vl":
|
||||
if position_ids.dim() == 1:
|
||||
position_ids = self.model.get_position_ids(input_ids)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
@ -1806,7 +1810,7 @@ class FlashCausalLM(Model):
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
@ -1981,7 +1985,7 @@ class FlashCausalLM(Model):
|
||||
# instantly become of shape [BATCH_SIZE]
|
||||
if prefill and finished_prefilling:
|
||||
indices = batch.cu_seqlen_prefill[1:] - 1
|
||||
batch.position_ids = batch.position_ids[indices]
|
||||
batch.position_ids = batch.position_ids[(..., indices)]
|
||||
batch.slot_indices = batch.slot_indices[indices]
|
||||
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
|
||||
indices
|
||||
|
@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if hasattr(self.model, "get_position_ids"):
|
||||
if position_ids.shape[0] != 1:
|
||||
if self.model.config.model_type == "qwen2_vl":
|
||||
if position_ids.dim() == 1:
|
||||
position_ids = self.model.get_position_ids(
|
||||
input_ids.unsqueeze(0), batch.image_grid_thw
|
||||
input_ids, batch.image_grid_thw
|
||||
)
|
||||
batch.position_ids = position_ids[0, 0, :]
|
||||
else:
|
||||
position_ids = position_ids.repeat(3, 1, 1).clone()
|
||||
batch.position_ids = position_ids[0, 0, :]
|
||||
batch.position_ids = position_ids
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
|
Loading…
Reference in New Issue
Block a user