From 70409f09f4653d7598f3f1de72ca42f459082f5a Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Mon, 18 Nov 2024 15:34:08 +0000
Subject: [PATCH] fix: calc max_seqlen once and small refactors

---
 .../models/custom_modeling/qwen2_vl.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 70a6a39c..ddb4e36d 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -97,7 +97,7 @@ class Qwen2VLAttention(nn.Module):
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
         qkv = self.qkv(hidden_state)
@@ -122,7 +122,6 @@ class Qwen2VLAttention(nn.Module):
         key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
 
         # calc maximum sequence length for any batch
-        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
@@ -220,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
+    ) -> torch.Tensor:
         hidden_states_post_norm1, res = self.norm1(hidden_states)
         hidden_states = hidden_states + self.attn(
-            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
+            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
         )
         hidden_states_post_norm2, res = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
@@ -246,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
             prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
         )
 
-    def forward(self, hidden_states, grid_thw) -> torch.Tensor:
+    def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.patch_merger_ln_q(hidden_states)
         hidden_states = hidden_states.view(-1, self.hidden_size)
         hidden_states = self.fc1(hidden_states)
@@ -307,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.Tensor,
-        aspect_ratio_ids: Optional[torch.Tensor] = None,
         grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         # reshape the input tensor for processing
@@ -362,13 +362,13 @@ class Qwen2VisionModel(nn.Module):
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-
+        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iterately apply the blocks to the hidden states
         for block in self.blocks:
-            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
+            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
 
         # apply the final patch merger to the hidden states
-        hidden_states = self.merger(hidden_states, grid_thw)
+        hidden_states = self.merger(hidden_states)
         return hidden_states
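
For reference, the core of the change is hoisting the `max_seqlen` computation out of every `Qwen2VLAttention.forward` call: it is now derived once from the cumulative sequence boundaries (`cu_seqlens`) in `Qwen2VisionModel.forward` and threaded down through each vision block. Below is a minimal standalone sketch of that relationship, using toy per-image patch counts rather than the model's real tensors; it is illustrative only, not part of the patch.

```python
import torch
import torch.nn.functional as F

# toy per-image patch counts (t * h * w patches per image); illustrative values only
seqlens = torch.tensor([4, 7, 5], dtype=torch.int32)

# cumulative boundaries, padded with a leading zero, mirroring the vision forward pass
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
# -> tensor([ 0,  4, 11, 16], dtype=torch.int32)

# longest single sequence, computed once and passed to every block / attention call
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# -> tensor(7, dtype=torch.int32)
```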