fix: improve and simplify get_cos_sin, refactors and cleanup get_position_ids

drbh 2025-02-04 00:25:59 +00:00
parent 9eaa163239
commit 6cb0cb68b4
3 changed files with 47 additions and 65 deletions


@@ -88,19 +88,13 @@ class PositionRotaryEmbedding(nn.Module):
             rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+            mrope_section = rope_scaling.get("mrope_section", None)
+
+            # only apply mrope if sections are provided and the rope type is mrope and a section is provided
+            if mrope_section is not None and rope_type == "mrope":
+                mrope_section = rope_scaling.get("mrope_section")
+                return RotaryPositionEmbeddingMultimodalSections(
+                    inv_freq, scaling_factor, mrope_section
+                )
+
             if rope_type == "linear":
                 pass
             elif rope_type == "default":
                 pass
-            elif rope_type == "mrope":
-                mrope_section = rope_scaling["mrope_section"]
-                if mrope_section is not None:
-                    return RotaryPositionEmbeddingMultimodalSections(
-                        inv_freq, scaling_factor, mrope_section
-                    )
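For reference, a Qwen2-VL-style rope_scaling entry carries the per-axis frequency split in mrope_section. A minimal standalone sketch (illustrative config values, a print instead of constructing the embedding class) of which branch the new early return takes:

# Sketch only -- not TGI code; the config values are illustrative.
rope_scaling = {
    "rope_type": "mrope",
    "mrope_section": [16, 24, 24],  # temporal / height / width frequency splits
}

rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
mrope_section = rope_scaling.get("mrope_section", None)

if mrope_section is not None and rope_type == "mrope":
    # TGI would return RotaryPositionEmbeddingMultimodalSections here
    print("multimodal sections rotary:", mrope_section)
else:
    print("fall through to the regular rope_type branches")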
@@ -569,6 +563,12 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
+        self.section_indices = (
+            torch.arange(len(self.sections))
+            .repeat_interleave(torch.tensor(self.sections))
+            .view(1, 1, -1)
+            .to(inv_freq.device)
+        )
 
     def forward(
         self,
@@ -599,6 +599,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)
+            self._sections = self.section_indices.expand(seqlen, -1, -1)
 
     def get_cos_sin(
         self,
@@ -607,24 +608,11 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
 
-        # access freqs for each of the 3 sections and stack them
-        cos_c = torch.stack(
-            [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
-        )
-        sin_c = torch.stack(
-            [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
-        )
-
-        # chunk based on sections
-        split_cos = torch.split(cos_c, self.sections, dim=-1)
-        split_sin = torch.split(sin_c, self.sections, dim=-1)
-
-        # for each section, select the corresponding cos/sin (0, 1, 2, ...)
-        cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
-        sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
-
-        # double the size and add a batch dimension
-        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
-        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)
+        slen = position_ids.shape[0]
+        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
+        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
+
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([sin, sin], dim=-1)
         return cos, sin
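Outside the diff, a quick way to convince yourself the two formulations agree: the precomputed section_indices marks, for every frequency index, which of the three axes (t/h/w) owns it, so a single gather replaces the stack/split/cat dance. A standalone sketch with toy sizes (sections, dim_half, max_pos and slen are made up, not Qwen2-VL's real values):

# Equivalence check, not TGI code.
import torch

sections = [2, 3, 3]                # toy stand-in for mrope_section; sums to dim // 2
dim_half = sum(sections)
max_pos, slen = 32, 5

cos_cached = torch.randn(max_pos, dim_half)
position_ids = torch.randint(0, max_pos, (slen, 3))   # one (t, h, w) id triple per token

# old approach: stack the rows for each axis, split by sections, keep axis i for chunk i
cos_c = torch.stack([cos_cached[position_ids[:, i]] for i in range(3)], dim=0)
split_cos = torch.split(cos_c, sections, dim=-1)
old = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(1)

# new approach: one advanced index, then a single gather over the axis dimension
section_indices = (
    torch.arange(len(sections))
    .repeat_interleave(torch.tensor(sections))
    .view(1, 1, -1)
)
new = cos_cached[position_ids].gather(1, section_indices.expand(slen, -1, -1))

assert torch.equal(old, new)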


@@ -235,8 +235,7 @@ class Qwen2Layer(nn.Module):
         max_s,
         prefill_cache_indices,
     ):
-        residual = hidden_states
-        normed_hidden_states, _ = self.input_layernorm(hidden_states)
+        normed_hidden_states, residual = self.input_layernorm(hidden_states)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -254,8 +253,7 @@ class Qwen2Layer(nn.Module):
         hidden_states = attn_output + residual
 
         # faster post attention rms norm
-        residual = hidden_states
-        hidden_states, _ = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states)
         mlp_output = self.mlp(hidden_states)
         hidden_states = mlp_output + residual
         return hidden_states
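The Qwen2Layer changes lean on the norm's return contract rather than a manually saved copy of the input. A minimal sketch of that contract, assuming the layernorm behaves like TGI's fused RMSNorm wrappers (returning the tensor to be used as the next residual alongside the normalized output); RMSNormWithResidual below is a made-up stand-in, not the class used in the model:

# Sketch of the assumed contract, not TGI's FastRMSNorm.
import torch
from torch import nn


class RMSNormWithResidual(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        # optional fused residual add before normalizing
        if residual is not None:
            hidden_states = hidden_states + residual
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + self.eps)
        # hand back the (possibly summed) pre-norm tensor for the next residual add
        return self.weight * normed, hidden_states


norm = RMSNormWithResidual(8)
x = torch.randn(2, 8)
normed, residual = norm(x)
assert torch.equal(residual, x)  # same value the dropped `residual = hidden_states` line kept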


@@ -222,10 +222,10 @@ class Qwen2VLVisionBlock(nn.Module):
     def forward(
         self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
     ) -> torch.Tensor:
-        norm1_out, _ = self.norm1(hidden_states)
+        norm1_out, residual = self.norm1(hidden_states)
         attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
-        hidden_states = hidden_states + attn_out
-        norm2_out, _ = self.norm2(hidden_states)
+        hidden_states = attn_out + residual
+        norm2_out, residual = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(norm2_out)
         return hidden_states
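Assuming the same norm contract as in the sketch above, the vision-block rewrite is numerically identical to what it replaces, since norm1 hands back its untouched input as the residual. A toy check with stand-in tensors:

# Stand-in tensors only; the shapes and values are arbitrary.
import torch

hidden_states = torch.randn(4, 8)
residual = hidden_states          # what norm1 would return alongside norm1_out
attn_out = torch.randn(4, 8)      # stand-in for self.attn(...)

assert torch.equal(
    hidden_states + attn_out,     # old: residual add against the block input
    attn_out + residual,          # new: residual add against norm1's second return value
)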
@@ -410,52 +410,52 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         self.device = weights.device
 
     # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
     # modified to first find segments then initialize position ids for each segment
+    # Steps:
+    #  locate all vision and text segments
+    #  calculate `vision_segment_lengths` for each vision segment to be use as offset
+    #  calculate `text_segment_lengths` for each text segment to be used as offset
+    #  create position ids for each vision segment based on the image grid
+    #  create position ids for each text segment
+    #  combine all the position ids
+    #  the final segment is the difference between the last vision segment and the end of the input
+    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
     def get_position_ids(
         self,
         input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if image_grid_thw is None:
             # (batch_size, 3)
             return (
                 torch.arange(input_ids.shape[0], device=input_ids.device)
                 .unsqueeze(1)
                 .repeat(1, 3)
             )
 
         # if image grid provided than we need to calculate the position ids
         spatial_merge_size = self.spatial_merge_size
         vision_start_token_id = self.vision_start_token_id
         vision_end_token_id = self.vision_end_token_id
         device = input_ids.device
         dtype = input_ids.dtype
         input_ids_len = input_ids.shape[0]
 
-        # capture vision segments
-        starts = torch.where(input_ids == vision_start_token_id)[0]
-        ends = torch.where(input_ids == vision_end_token_id)[0]
-        # ie. [[ 14, 2181], [2212, 4379]]
-        vision_segments = torch.stack((starts, ends), dim=1)
-        # capture text lengths as the space between vision segments
-        prev_end = torch.cat(  # shift to the left to get the previous end
-            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
-        )  # ie. [0, 2181]
-        # text is the space between the end of one vision segment and the start of the next
-        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]
-        # calculate the max id from the image width for each segment
+        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+        prev_vision_end = torch.cat(
+            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+        )
+        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+
         vision_widths_max = torch.cat(
             [
                 torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                 image_grid_thw[:-1, 2] // spatial_merge_size,
             ]
         )
-        total_segment_lengths = vision_widths_max + text_lengths
-        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
-        text_diff = total_segment_lengths - text_lengths
+        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
 
         # create position ids for each vision segment based on the image grid
         llm_pos_ids_list = []
@@ -471,29 +471,25 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
             # offset by the position of the last vision segment
-            im = image_position_ids + total_segment_lengths[i]
+            im = image_position_ids + vision_segment_lengths[i]
             llm_pos_ids_list.append(im)
 
         # create position ids for each text segment
         text_ranges = [
             torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
-            + text_diff[i]
-            for i, seq_len in enumerate(text_lengths)
-        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
+            + text_segment_lengths[i]
+            for i, seq_len in enumerate(text_lengths_between_vision)
+        ]
 
         # combine by alternating text and vision segments (text, vision, text, vision, ...)
         full_llm_pos_ids_list = [
             item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
         ]
 
-        # the final segment is the difference between the last vision segment and the end of the input
         max_s = full_llm_pos_ids_list[-1].max() + 1
-        final_text_len = input_ids_len - ends[-1]
+        final_text_len = input_ids_len - vision_ends[-1]
         if final_text_len > 0:
             m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
             full_llm_pos_ids_list.append(m + max_s)
 
-        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
         position_ids = (
             torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
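To make the segment bookkeeping above concrete, here is a standalone toy walk-through (made-up token ids, a single image whose grid is treated as already merged, so spatial_merge_size and batching are ignored). It mirrors the steps in the comment block rather than calling the model:

# Toy walk-through only; token ids and the grid are invented for illustration.
import torch

VISION_START, VISION_END, IMAGE_PAD = 1000, 1001, 1002

# 4 text tokens, one image of (t=1, h=2, w=2) patches, then vision_end + 3 text tokens
input_ids = torch.tensor(
    [5, 6, 7, 8, VISION_START] + [IMAGE_PAD] * 4 + [VISION_END, 9, 10, 11]
)
grid_thw = torch.tensor([[1, 2, 2]])

# locate vision segments and the text stretches between them
vision_starts = torch.where(input_ids == VISION_START)[0]
vision_ends = torch.where(input_ids == VISION_END)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat([torch.zeros(1, dtype=torch.long), vision_ends[:-1]])
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1   # tensor([5])

# cumulative offsets: where each vision segment and each text segment starts counting
vision_widths_max = torch.cat([torch.zeros(1, dtype=torch.long), grid_thw[:-1, 2]])
vision_segment_lengths = (vision_widths_max + text_lengths_between_vision).cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision  # tensor([0])

# vision position ids follow the (t, h, w) grid, shifted by the segment offset
llm_pos_ids_list = []
for i, (t, h, w) in enumerate(grid_thw.tolist()):
    t_idx = torch.arange(t).repeat_interleave(h * w)
    h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
    w_idx = torch.arange(w).repeat(t * h)
    image_position_ids = torch.stack([t_idx, h_idx, w_idx]) + vision_segment_lengths[i]
    llm_pos_ids_list.append(image_position_ids)

# text position ids are simple ranges, shifted by their own offsets
text_ranges = [
    torch.arange(int(n)).view(1, -1).expand(3, -1) + text_segment_lengths[i]
    for i, n in enumerate(text_lengths_between_vision)
]
full = [x for pair in zip(text_ranges, llm_pos_ids_list) for x in pair]

# trailing text after the last vision segment continues from the running maximum
max_s = full[-1].max() + 1
final_text_len = input_ids.shape[0] - vision_ends[-1]
if final_text_len > 0:
    full.append(torch.arange(int(final_text_len)).view(1, -1).expand(3, -1) + max_s)

position_ids = torch.cat(full, dim=1).reshape(3, -1).transpose(0, 1)
print(position_ids.shape)          # torch.Size([13, 3]) -- one (t, h, w) triple per token
print(position_ids[:6].tolist())   # [[0,0,0], ..., [4,4,4], [5,5,5]] -- text, then the first vision token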