fix: calc max_seqlen once and small refactors

David Holtz 2024-11-18 15:34:08 +00:00
parent 41dff3147d
commit 70409f09f4

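The change rests on one observation: max_seqlen is derived purely from
cu_seqlens, and cu_seqlens is identical for every vision block, so the
reduction can run once in Qwen2VisionModel.forward instead of once per
attention layer. A minimal sketch of that computation, using a made-up
cu_seqlens tensor (the values below are illustrative, not from the model):

import torch

# hypothetical cumulative sequence lengths for three images whose patch
# sequences have lengths 4, 2, and 6; the leading 0 mirrors the F.pad call
cu_seqlens = torch.tensor([0, 4, 6, 12], dtype=torch.int32)

# adjacent differences recover the per-image lengths: tensor([4, 2, 6])
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

# the maximum is the same no matter which block asks for it
max_seqlen = torch.max(seqlens)  # tensor(6, dtype=torch.int32)
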

@@ -97,7 +97,7 @@ class Qwen2VLAttention(nn.Module):
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
         qkv = self.qkv(hidden_state)
@@ -122,7 +122,6 @@ class Qwen2VLAttention(nn.Module):
         key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

         # calc maximum sequence length for any batch
-        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
@@ -220,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
             weights=weights,
         )

-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
+    ) -> torch.Tensor:
         hidden_states_post_norm1, res = self.norm1(hidden_states)
         hidden_states = hidden_states + self.attn(
-            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
+            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
         )
         hidden_states_post_norm2, res = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
@@ -246,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
             prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
         )

-    def forward(self, hidden_states, grid_thw) -> torch.Tensor:
+    def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.patch_merger_ln_q(hidden_states)
         hidden_states = hidden_states.view(-1, self.hidden_size)
         hidden_states = self.fc1(hidden_states)
@@ -307,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.Tensor,
-        aspect_ratio_ids: Optional[torch.Tensor] = None,
         grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         # reshape the input tensor for processing
@@ -362,13 +362,13 @@
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

         # iteratively apply the blocks to the hidden states
         for block in self.blocks:
-            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
+            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)

         # apply the final patch merger to the hidden states
-        hidden_states = self.merger(hidden_states, grid_thw)
+        hidden_states = self.merger(hidden_states)
         return hidden_states
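
Taken together, the reduction now runs once per forward pass rather than
once per block. A stripped-down sketch of the resulting control flow; the
Block and VisionModel classes here are stand-ins for Qwen2VLVisionBlock and
Qwen2VisionModel, not the actual TGI modules:

import torch
import torch.nn as nn

class Block(nn.Module):
    # stand-in for Qwen2VLVisionBlock: accepts max_seqlen instead of
    # recomputing it inside its attention layer
    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen):
        return hidden_states  # a real block would run attention + MLP here

class VisionModel(nn.Module):
    # stand-in for Qwen2VisionModel
    def __init__(self, depth=2):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(depth))

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb):
        # computed once here, then threaded through every block
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
            )
        return hidden_states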