Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

fix: improve and simplify get_cos_sin, refactors and cleanup get_position_ids

Commit 6cb0cb68b4, parent 9eaa163239
@@ -88,22 +88,16 @@ class PositionRotaryEmbedding(nn.Module):
             rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
             mrope_section = rope_scaling.get("mrope_section", None)
 
-            # only apply mrope if sections are provided and the rope type is mrope and a section is provided
-            if mrope_section is not None and rope_type == "mrope":
-                mrope_section = rope_scaling.get("mrope_section")
-                return RotaryPositionEmbeddingMultimodalSections(
-                    inv_freq, scaling_factor, mrope_section
-                )
-
             if rope_type == "linear":
                 pass
             elif rope_type == "default":
                 pass
             elif rope_type == "mrope":
                 mrope_section = rope_scaling["mrope_section"]
-                return RotaryPositionEmbeddingMultimodalSections(
-                    inv_freq, scaling_factor, mrope_section
-                )
+                if mrope_section is not None:
+                    return RotaryPositionEmbeddingMultimodalSections(
+                        inv_freq, scaling_factor, mrope_section
+                    )
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
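As a reading aid for the hunk above, here is a minimal standalone sketch of the dispatch behaviour with an invented Qwen2-VL style rope_scaling dict; the selected names are plain strings standing in for the real TGI classes, not calls into them. The net effect of the hunk is that the mrope case now lives alongside the other rope_type branches instead of being special-cased before them.

# Illustration only: config values are made up, strings stand in for classes.
rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
mrope_section = rope_scaling.get("mrope_section", None)

if rope_type == "mrope" and mrope_section is not None:
    selected = "RotaryPositionEmbeddingMultimodalSections"
elif rope_type == "dynamic":
    selected = "DynamicPositionRotaryEmbedding"
elif rope_type in ("linear", "default"):
    selected = "base rotary embedding"  # assumption: handled by the default path
else:
    selected = "unhandled"

print(selected)  # RotaryPositionEmbeddingMultimodalSections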
@@ -569,6 +563,12 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
+        self.section_indices = (
+            torch.arange(len(self.sections))
+            .repeat_interleave(torch.tensor(self.sections))
+            .view(1, 1, -1)
+            .to(inv_freq.device)
+        )
 
     def forward(
         self,
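To see what the new section_indices buffer holds, here is a small standalone check with invented section sizes (the real values come from rope_scaling["mrope_section"] in the model config):

import torch

# Each rotary frequency is tagged with the index of the mrope section it
# belongs to, e.g. sections of sizes [2, 3, 1] give [0, 0, 1, 1, 1, 2].
sections = [2, 3, 1]  # invented sizes for illustration
section_indices = (
    torch.arange(len(sections))
    .repeat_interleave(torch.tensor(sections))
    .view(1, 1, -1)
)
print(section_indices)  # tensor([[[0, 0, 1, 1, 1, 2]]])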
@@ -599,6 +599,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)
+            self._sections = self.section_indices.expand(seqlen, -1, -1)
 
     def get_cos_sin(
         self,
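The added expand call only creates a broadcast view of the (1, 1, dim) index tensor, one row per cached position, so it costs no extra memory. A toy shape check (shapes invented):

import torch

# expand does not copy data; it broadcasts the section indices over seqlen rows
# so they can be used directly as a gather index in get_cos_sin.
seqlen = 4
section_indices = torch.tensor([[[0, 0, 1, 1, 1, 2]]])  # (1, 1, dim)
sections_view = section_indices.expand(seqlen, -1, -1)  # (seqlen, 1, dim)
print(sections_view.shape)  # torch.Size([4, 1, 6])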
@@ -607,24 +608,11 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+        slen = position_ids.shape[0]
+
+        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
+        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
+
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([sin, sin], dim=-1)
-
-        # access freqs for each of the 3 sections and stack them
-        cos_c = torch.stack(
-            [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
-        )
-        sin_c = torch.stack(
-            [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
-        )
-
-        # chunk based on sections
-        split_cos = torch.split(cos_c, self.sections, dim=-1)
-        split_sin = torch.split(sin_c, self.sections, dim=-1)
-
-        # for each section, select the corresponding cos/sin (0, 1, 2, ...)
-        cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
-        sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
-
-        # double the size and add a batch dimension
-        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
-        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)
         return cos, sin
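The hunk above replaces the stack/split/reselect pipeline with a single indexed lookup plus gather. A small equivalence check on made-up shapes (not TGI code, just the two selection strategies side by side) shows they pick the same cos values:

import torch

torch.manual_seed(0)
slen, dim = 5, 6
sections = [2, 3, 1]                           # invented mrope sections, sum == dim
cos_cached = torch.randn(8, dim)               # one row per cached position
position_ids = torch.randint(0, 8, (slen, 3))  # (seq_len, 3) multimodal positions

# old style: stack the three per-section lookups, split by section, reselect
cos_c = torch.stack([cos_cached[position_ids[:, i]] for i in range(3)], dim=0)
split_cos = torch.split(cos_c, sections, dim=-1)
old = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)

# new style: one advanced-index lookup plus a gather over precomputed section indices
section_indices = (
    torch.arange(len(sections))
    .repeat_interleave(torch.tensor(sections))
    .view(1, 1, -1)
)
new = cos_cached[position_ids].gather(1, section_indices.expand(slen, -1, -1))

print(torch.allclose(old, new.squeeze(1)))  # True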
@@ -235,8 +235,7 @@ class Qwen2Layer(nn.Module):
         max_s,
         prefill_cache_indices,
     ):
-        residual = hidden_states
-        normed_hidden_states, _ = self.input_layernorm(hidden_states)
+        normed_hidden_states, residual = self.input_layernorm(hidden_states)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -254,8 +253,7 @@ class Qwen2Layer(nn.Module):
         hidden_states = attn_output + residual
 
         # faster post attention rms norm
-        residual = hidden_states
-        hidden_states, _ = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states)
         mlp_output = self.mlp(hidden_states)
         hidden_states = mlp_output + residual
         return hidden_states
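Both Qwen2Layer hunks rely on the layer norm returning a residual alongside the normalized activations. The sketch below is a stand-in, not TGI's fused RMSNorm (name, signature, and fusion details are assumptions); it only illustrates why the explicit residual = hidden_states assignment becomes redundant when no residual is passed in:

import torch
from torch import nn


class DummyRMSNorm(nn.Module):
    """Hypothetical norm mimicking a (normed, residual) return convention."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual  # fused add, as in fast norms
        residual = hidden_states                      # hand back the pre-norm input
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + self.eps) * self.weight
        return normed, residual


x = torch.randn(2, 8)
normed, residual = DummyRMSNorm(8)(x)
assert torch.equal(residual, x)  # the norm already returns what the old code saved by hand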
@@ -222,10 +222,10 @@ class Qwen2VLVisionBlock(nn.Module):
     def forward(
         self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
     ) -> torch.Tensor:
-        norm1_out, _ = self.norm1(hidden_states)
+        norm1_out, residual = self.norm1(hidden_states)
         attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
-        hidden_states = hidden_states + attn_out
-        norm2_out, _ = self.norm2(hidden_states)
+        hidden_states = attn_out + residual
+        norm2_out, residual = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(norm2_out)
         return hidden_states
 
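The vision block hunk is the same simplification: assuming norm1/norm2 follow the convention sketched above and hand back their pre-norm input as the residual, attn_out + residual computes the same sum as the old hidden_states + attn_out, with the residual taken from the norm's return value instead of a separate local variable.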
@@ -410,52 +410,52 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         self.device = weights.device
 
+    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+    # modified to first find segments then initialize position ids for each segment
+    # Steps:
+    #  locate all vision and text segments
+    #  calculate `vision_segment_lengths` for each vision segment to be use as offset
+    #  calculate `text_segment_lengths` for each text segment to be used as offset
+    #  create position ids for each vision segment based on the image grid
+    #  create position ids for each text segment
+    #  combine all the position ids
+    #  the final segment is the difference between the last vision segment and the end of the input
+    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
     def get_position_ids(
         self,
         input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if image_grid_thw is None:
-            # (batch_size, 3)
             return (
                 torch.arange(input_ids.shape[0], device=input_ids.device)
                 .unsqueeze(1)
                 .repeat(1, 3)
             )
 
-        # if image grid provided than we need to calculate the position ids
         spatial_merge_size = self.spatial_merge_size
         vision_start_token_id = self.vision_start_token_id
         vision_end_token_id = self.vision_end_token_id
 
         device = input_ids.device
         dtype = input_ids.dtype
         input_ids_len = input_ids.shape[0]
 
-        # capture vision segments
-        starts = torch.where(input_ids == vision_start_token_id)[0]
-        ends = torch.where(input_ids == vision_end_token_id)[0]
-        # ie. [[ 14, 2181], [2212, 4379]]
-        vision_segments = torch.stack((starts, ends), dim=1)
-
-        # capture text lengths as the space between vision segments
-        prev_end = torch.cat(  # shift to the left to get the previous end
-            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
-        )  # ie. [0, 2181]
-
-        # text is the space between the end of one vision segment and the start of the next
-        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]
-
-        # calculate the max id from the image width for each segment
+        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+        prev_vision_end = torch.cat(
+            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+        )
+        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+
         vision_widths_max = torch.cat(
             [
                 torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                 image_grid_thw[:-1, 2] // spatial_merge_size,
             ]
         )
-        total_segment_lengths = vision_widths_max + text_lengths
-        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
-        text_diff = total_segment_lengths - text_lengths
+        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
 
         # create position ids for each vision segment based on the image grid
         llm_pos_ids_list = []
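To make the renamed offsets above concrete, here is a worked toy example (token ids 100/101 stand in for the real vision start/end ids, and the grid values are invented):

import torch

# Two vision segments embedded in text; everything is tiny and made up.
input_ids = torch.tensor([1, 1, 1, 100, 5, 5, 101, 1, 1, 100, 5, 5, 101, 1])
vision_start_token_id, vision_end_token_id = 100, 101
spatial_merge_size = 2
image_grid_thw = torch.tensor([[1, 2, 4], [1, 2, 4]])  # (t, h, w) per image
dtype = input_ids.dtype

vision_starts = torch.where(input_ids == vision_start_token_id)[0]
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat([torch.zeros(1, dtype=dtype), vision_ends[:-1]])
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1

vision_widths_max = torch.cat(
    [torch.zeros(1, dtype=dtype), image_grid_thw[:-1, 2] // spatial_merge_size]
)
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

print(text_lengths_between_vision)  # tensor([4, 4])
print(vision_segment_lengths)       # tensor([ 4, 10])
print(text_segment_lengths)         # tensor([0, 6])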
@@ -471,29 +471,25 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
 
             # offset by the position of the last vision segment
-            im = image_position_ids + total_segment_lengths[i]
+            im = image_position_ids + vision_segment_lengths[i]
             llm_pos_ids_list.append(im)
 
         # create position ids for each text segment
         text_ranges = [
             torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
-            + text_diff[i]
-            for i, seq_len in enumerate(text_lengths)
-        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
+            + text_segment_lengths[i]
+            for i, seq_len in enumerate(text_lengths_between_vision)
+        ]
 
-        # combine by alternating text and vision segments (text, vision, text, vision, ...)
         full_llm_pos_ids_list = [
             item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
         ]
 
-        # the final segment is the difference between the last vision segment and the end of the input
         max_s = full_llm_pos_ids_list[-1].max() + 1
-        final_text_len = input_ids_len - ends[-1]
+        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
             m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
             full_llm_pos_ids_list.append(m + max_s)
 
-        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
         position_ids = (
             torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
         )
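And a minimal check of the final assembly step from the hunk above (block contents are placeholders, not real position ids): alternating (3, n) text and vision blocks are concatenated along dim=1 and transposed to (input_ids_len, 3).

import torch

text_block = torch.arange(4).view(1, -1).expand(3, -1)  # (3, 4) placeholder text positions
vision_block = torch.full((3, 2), 4)                     # (3, 2) placeholder vision positions
full_llm_pos_ids_list = [text_block, vision_block]

position_ids = (
    torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
print(position_ids.shape)  # torch.Size([6, 3])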