From 6cb0cb68b455c379959490ed41cf26303c544da7 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 4 Feb 2025 00:25:59 +0000 Subject: [PATCH] fix: improve and simplify get_cos_sin, refactors and cleanup get_position_ids --- .../text_generation_server/layers/rotary.py | 44 +++++-------- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../models/custom_modeling/qwen2_vl.py | 62 +++++++++---------- 3 files changed, 47 insertions(+), 65 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index b40d413a..576aeb52 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -88,22 +88,16 @@ class PositionRotaryEmbedding(nn.Module): rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) mrope_section = rope_scaling.get("mrope_section", None) - # only apply mrope if sections are provided and the rope type is mrope and a section is provided - if mrope_section is not None and rope_type == "mrope": - mrope_section = rope_scaling.get("mrope_section") - return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section - ) - if rope_type == "linear": pass elif rope_type == "default": pass elif rope_type == "mrope": mrope_section = rope_scaling["mrope_section"] - return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section - ) + if mrope_section is not None: + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -569,6 +563,12 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): self.sections = sections self._cos_cached = None self._sin_cached = None + self.section_indices = ( + torch.arange(len(self.sections)) + .repeat_interleave(torch.tensor(self.sections)) + .view(1, 1, -1) + .to(inv_freq.device) + ) def forward( self, @@ -599,6 +599,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): freqs = torch.outer(t, self.inv_freq.to(device=t.device)) self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) + self._sections = self.section_indices.expand(seqlen, -1, -1) def get_cos_sin( self, @@ -607,24 +608,11 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): dtype: torch.dtype, ): self._update_cos_sin_cache(dtype, position_ids.device, max_s) + slen = position_ids.shape[0] - # access freqs for each of the 3 sections and stack them - cos_c = torch.stack( - [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0 - ) - sin_c = torch.stack( - [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0 - ) + cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) + sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) - # chunk based on sections - split_cos = torch.split(cos_c, self.sections, dim=-1) - split_sin = torch.split(sin_c, self.sections, dim=-1) - - # for each section, select the corresponding cos/sin (0, 1, 2, ...) 
- cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1) - sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) - - # double the size and add a batch dimension - cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1) - sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1) + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) return cos, sin diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 78ae3020..d6569a1d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -235,8 +235,7 @@ class Qwen2Layer(nn.Module): max_s, prefill_cache_indices, ): - residual = hidden_states - normed_hidden_states, _ = self.input_layernorm(hidden_states) + normed_hidden_states, residual = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -254,8 +253,7 @@ class Qwen2Layer(nn.Module): hidden_states = attn_output + residual # faster post attention rms norm - residual = hidden_states - hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states) mlp_output = self.mlp(hidden_states) hidden_states = mlp_output + residual return hidden_states diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 4031fe8f..2d017e38 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -222,10 +222,10 @@ class Qwen2VLVisionBlock(nn.Module): def forward( self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen ) -> torch.Tensor: - norm1_out, _ = self.norm1(hidden_states) + norm1_out, residual = self.norm1(hidden_states) attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) - hidden_states = hidden_states + attn_out - norm2_out, _ = self.norm2(hidden_states) + hidden_states = attn_out + residual + norm2_out, residual = self.norm2(hidden_states) hidden_states = hidden_states + self.mlp(norm2_out) return hidden_states @@ -410,52 +410,52 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) self.device = weights.device + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) def get_position_ids( self, input_ids: torch.Tensor, image_grid_thw: Optional[torch.Tensor] = None, ) -> torch.Tensor: if image_grid_thw is None: - # (batch_size, 3) return ( torch.arange(input_ids.shape[0], device=input_ids.device) .unsqueeze(1) .repeat(1, 3) ) - # if image grid 
provided than we need to calculate the position ids spatial_merge_size = self.spatial_merge_size vision_start_token_id = self.vision_start_token_id vision_end_token_id = self.vision_end_token_id - device = input_ids.device dtype = input_ids.dtype input_ids_len = input_ids.shape[0] - # capture vision segments - starts = torch.where(input_ids == vision_start_token_id)[0] - ends = torch.where(input_ids == vision_end_token_id)[0] - # ie. [[ 14, 2181], [2212, 4379]] - vision_segments = torch.stack((starts, ends), dim=1) - # capture text lengths as the space between vision segments - - prev_end = torch.cat( # shift to the left to get the previous end - [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]] - ) # ie. [0, 2181] - - # text is the space between the end of one vision segment and the start of the next - text_lengths = vision_segments[:, 0] - prev_end + 1 # ie. [15, 32] - - # calculate the max id from the image width for each segment + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 vision_widths_max = torch.cat( [ torch.zeros(1, device=image_grid_thw.device, dtype=dtype), image_grid_thw[:-1, 2] // spatial_merge_size, ] ) - total_segment_lengths = vision_widths_max + text_lengths - total_segment_lengths = total_segment_lengths.cumsum(dim=0) - text_diff = total_segment_lengths - text_lengths + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision # create position ids for each vision segment based on the image grid llm_pos_ids_list = [] @@ -471,29 +471,25 @@ class Qwen2VLForConditionalGeneration(nn.Module): image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) # offset by the position of the last vision segment - im = image_position_ids + total_segment_lengths[i] + im = image_position_ids + vision_segment_lengths[i] llm_pos_ids_list.append(im) # create position ids for each text segment text_ranges = [ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) - + text_diff[i] - for i, seq_len in enumerate(text_lengths) - ] # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]] + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] - # combine by alternating text and vision segments (text, vision, text, vision, ...) full_llm_pos_ids_list = [ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist ] - - # the final segment is the difference between the last vision segment and the end of the input max_s = full_llm_pos_ids_list[-1].max() + 1 - final_text_len = input_ids_len - ends[-1] + final_text_len = input_ids_len - vision_ends[-1] if final_text_len > 0: m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) full_llm_pos_ids_list.append(m + max_s) - # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) position_ids = ( torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) )
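
Note on the rotary.py change: below is a minimal, self-contained sketch (toy sizes, not TGI code) of what the new gather-based `get_cos_sin` computes, and why it matches the stack/split/concat version it replaces. The section sizes, positions and dimensions are made up for illustration only.

import torch

sections = [2, 3, 3]                  # toy mrope_section: t/h/w bands, sum == head_dim // 2
dim_half = sum(sections)
max_pos = 32

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim_half).float() / dim_half))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
cos_cached = torch.cos(freqs)         # (max_pos, dim_half), as in _update_cos_sin_cache

# one row per token, columns are the (temporal, height, width) position ids
position_ids = torch.tensor([[0, 0, 0], [1, 0, 1], [1, 1, 0], [5, 5, 5]])

# 0/1/2 per frequency dim, marking which axis owns that band (the new section_indices)
section_indices = (
    torch.arange(len(sections))
    .repeat_interleave(torch.tensor(sections))
    .view(1, 1, -1)                   # (1, 1, dim_half)
)

seq_len = position_ids.shape[0]
cos_all = cos_cached[position_ids]    # (seq_len, 3, dim_half): cos for all three axes
cos_new = cos_all.gather(1, section_indices.expand(seq_len, -1, -1))  # (seq_len, 1, dim_half)

# old formulation: stack per axis, split by section, pick axis i for section i
cos_stacked = torch.stack([cos_cached[position_ids[:, i]] for i in range(3)], dim=0)
picked = [m[i % 3] for i, m in enumerate(torch.split(cos_stacked, sections, dim=-1))]
cos_old = torch.cat(picked, dim=-1).unsqueeze(1)

assert torch.allclose(cos_new, cos_old)   # both are then doubled with cat([cos, cos], dim=-1)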
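
Note on the Qwen2Layer / vision-block change: the layers now take the residual from the norm's second return value instead of copying hidden_states beforehand. A toy sketch of that calling convention, assuming (as the patch relies on) that the fused norm returns the un-normalized input as its second output; all names below are illustrative stand-ins, not TGI classes.

import torch
import torch.nn as nn

class RMSNormWithResidual(nn.Module):
    """Toy RMSNorm mirroring the (normed, residual) return convention."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        variance = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(variance + self.eps) * self.weight
        return normed, x               # second value doubles as the residual

class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.input_layernorm = RMSNormWithResidual(dim)
        self.post_attention_layernorm = RMSNormWithResidual(dim)
        self.self_attn = nn.Linear(dim, dim)   # stand-ins for attention / MLP
        self.mlp = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # new pattern: the norm hands back the residual, no manual copy needed
        normed, residual = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(normed) + residual
        normed, residual = self.post_attention_layernorm(hidden_states)
        return self.mlp(normed) + residual

out = ToyBlock(16)(torch.randn(2, 16))
print(out.shape)   # torch.Size([2, 16])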
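
Note on get_position_ids: a small numeric walk-through (toy token layout and grid, not TGI code) of the segment/offset bookkeeping the refactor introduces. Only the arithmetic mirrors the added lines; the token indices and grid are invented, and the per-patch positions quoted in the trailing comments assume the t/h/w grid indexing kept unchanged from the original implementation.

import torch

# toy layout: 4 text tokens, <vision_start>, 6 image patches, <vision_end>, 3 text tokens
vision_starts = torch.tensor([4])          # index of <vision_start>
vision_ends = torch.tensor([11])           # index of <vision_end>
image_grid_thw = torch.tensor([[1, 4, 6]]) # (t, h, w) before merging
spatial_merge_size = 2                     # 4//2 x 6//2 = 2 x 3 = 6 merged patches
input_ids_len = 15

prev_vision_end = torch.cat([torch.zeros(1, dtype=torch.long), vision_ends[:-1]])
# text run before each image, counting the <vision_start> token itself
text_lengths_between_vision = vision_starts - prev_vision_end + 1             # tensor([5])

# width offset contributed by the *previous* image (none before the first one)
vision_widths_max = torch.cat(
    [torch.zeros(1, dtype=torch.long), image_grid_thw[:-1, 2] // spatial_merge_size]
)
vision_segment_lengths = (vision_widths_max + text_lengths_between_vision).cumsum(0)  # tensor([5])
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision           # tensor([0])

# resulting position ids (t, h, w rows):
#   text 0..4   -> 0 1 2 3 4 on every axis (offset text_segment_lengths[0] == 0)
#   6 patches   -> t: 5 5 5 5 5 5   h: 5 5 5 6 6 6   w: 5 6 7 5 6 7   (offset 5)
#   final text  -> 8 9 10 11 on every axis (input_ids_len - vision_ends[-1] == 4 tokens,
#                  continuing at max position + 1)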