Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-07-28 19:00:20 +00:00
fix: simplify get position ids and remove unused vision config
This commit is contained in:
parent 6893eb3834
commit 68e3ee8e79
@@ -230,14 +230,7 @@ struct QuantizationConfig {
 }
 
 #[derive(Debug, Deserialize)]
-struct VisionConfig {
-    depth: Option<usize>,
-    embed_dim: Option<usize>,
-    mlp_ratio: Option<usize>,
-    in_chans: Option<usize>,
-    patch_size: Option<usize>,
-    temporal_patch_size: Option<usize>,
-}
+struct VisionConfig {}
 
 #[derive(Debug, Deserialize)]
 struct Config {
@@ -382,6 +382,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         config.rope_scaling.update({"rope_type": "mrope"})
         self.hidden_size = config.hidden_size
         self.vision_start_token_id = config.vision_start_token_id
+        self.vision_end_token_id = config.vision_end_token_id
         self.image_token_id = config.image_token_id
         self.video_token_id = config.video_token_id
         self.spatial_merge_size = config.vision_config.spatial_merge_size
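Note: the vision_end_token_id stored above is what lets the simplified get_position_ids bound each image's placeholder run from both sides instead of scanning for image tokens one by one. A minimal standalone sketch of that bookkeeping (not part of the commit, with made-up token ids rather than the real Qwen2-VL vocabulary), matching the torch.where/torch.stack pattern used in the diff below:

import torch

VISION_START, VISION_END, IMAGE_PAD = 1000, 1001, 1002  # toy ids, illustration only

# 3 text tokens, a 3-token image, 2 text tokens, a 1-token image, 1 text token
input_ids = torch.tensor(
    [7, 8, 9, VISION_START, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD, VISION_END,
     10, 11, VISION_START, IMAGE_PAD, VISION_END, 12]
)

starts = torch.where(input_ids == VISION_START)[0]    # tensor([ 3, 10])
ends = torch.where(input_ids == VISION_END)[0]        # tensor([ 7, 12])
# one row per image: [start index, end index]
vision_segments = torch.stack((starts, ends), dim=1)  # tensor([[ 3,  7], [10, 12]])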
@@ -411,98 +412,112 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_position_ids(
         self,
-        batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if batch_input_ids.dim() == 1:
-            batch_input_ids = batch_input_ids.unsqueeze(0)
-
-        position_ids = torch.ones(
-            3,
-            batch_input_ids.shape[0],
-            batch_input_ids.shape[1],
-            dtype=batch_input_ids.dtype,
-            device=batch_input_ids.device,
-        )
-        d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
-            llm_pos_ids_list = []
-
-            for i, input_ids in enumerate(batch_input_ids):
-                vision_start_indices = torch.argwhere(
-                    input_ids == self.vision_start_token_id
-                ).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
-                image_count = (vision_tokens == self.image_token_id).sum().item()
-
-                current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
-                        .nonzero()[0]
-                        .item()
-                    )
-                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
-                    height //= self.spatial_merge_size
-                    width //= self.spatial_merge_size
-
-                    # calculate the length of the text and image tokens
-                    text_length = next_image_pos
-                    start_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    )
-
-                    # text position ids
-                    text_pos_ids = torch.arange(text_length, device=d)
-                    text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
-                    llm_pos_ids_list.append(text_pos_ids)
-
-                    # image position ids
-                    t_indices = torch.arange(time_steps, device=d).repeat_interleave(
-                        height * width
-                    )
-                    h_indices = (
-                        torch.arange(height, device=d)
-                        .repeat_interleave(width)
-                        .repeat(time_steps)
-                    )
-                    w_indices = torch.arange(width, device=d).repeat(
-                        height * time_steps
-                    )
-
-                    image_pos_ids = (
-                        torch.stack([t_indices, h_indices, w_indices])
-                        + text_length
-                        + start_idx
-                    )
-                    llm_pos_ids_list.append(image_pos_ids)
-
-                    current_pos += next_image_pos + time_steps * height * width
-                    image_index += 1
-
-                if current_pos < batch_input_ids.size(1):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    )
-                    text_len = batch_input_ids.size(1) - current_pos
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
-                    )
-
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[:, i, :] = llm_positions.to(position_ids.device)
-        else:
-            position_ids = (
-                torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
-                .view(1, 1, -1)
-                .repeat(3, batch_input_ids.shape[0], 1)
-            )
-        return position_ids
+        input_ids: torch.Tensor,
+        image_grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO: avoid the early return and extra work in a more efficient way
+        if image_grid_thw is None:
+            if input_ids.dim() == 1:
+                input_ids = input_ids.unsqueeze(0)
+
+            position_ids = torch.ones(
+                3,
+                1,
+                input_ids.shape[0],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            position_ids = (
+                torch.arange(input_ids.shape[1], device=input_ids.device)
+                .view(1, 1, -1)
+                .repeat(3, input_ids.shape[0], 1)
+            )
+            return position_ids
+
+        # if the image grid is provided then we need to calculate the position ids
+        spatial_merge_size = self.spatial_merge_size
+        vision_start_token_id = self.vision_start_token_id
+        vision_end_token_id = self.vision_end_token_id
+
+        device = input_ids.device
+        dtype = input_ids.dtype
+        input_ids_len = input_ids.shape[0]
+
+        position_ids = torch.ones(
+            3,
+            input_ids_len,
+            dtype=dtype,
+            device=device,
+        )
+
+        # capture vision segments
+        starts = torch.where(input_ids == vision_start_token_id)[0]
+        ends = torch.where(input_ids == vision_end_token_id)[0]
+        # ie. [[ 14, 2181], [2212, 4379]]
+        vision_segments = torch.stack((starts, ends), dim=1)
+
+        # capture text lengths as the space between vision segments
+        prev_end = torch.cat(  # shift to the left to get the previous end
+            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
+        )  # ie. [0, 2181]
+        # text is the space between the end of one vision segment and the start of the next
+        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]
+
+        # calculate the max id from the image width for each segment
+        vision_widths_max = torch.cat(
+            [
+                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+                image_grid_thw[:-1, 2] // spatial_merge_size,
+            ]
+        )
+        total_segment_lengths = vision_widths_max + text_lengths
+        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
+        text_diff = total_segment_lengths - text_lengths
+
+        # create position ids for each vision segment based on the image grid
+        llm_pos_ids_list = []
+        for i, _ in enumerate(vision_segments):
+            t, h, w = (
+                image_grid_thw[i][0],
+                image_grid_thw[i][1] // spatial_merge_size,
+                image_grid_thw[i][2] // spatial_merge_size,
+            )
+            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+            w_indices = torch.arange(w, device=device).repeat(t * h)
+            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+            # offset by the position of the last vision segment
+            im = image_position_ids + total_segment_lengths[i]
+            llm_pos_ids_list.append(im)
+
+        # create position ids for each text segment
+        text_ranges = [
+            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+            + text_diff[i]
+            for i, seq_len in enumerate(text_lengths)
+        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
+
+        # combine by alternating text and vision segments (text, vision, text, vision, ...)
+        full_llm_pos_ids_list = [
+            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+        ]
+
+        # the final segment is the difference between the last vision segment and the end of the input
+        max_s = full_llm_pos_ids_list[-1].max() + 1
+        final_text_len = input_ids_len - ends[-1]
+        if final_text_len > 0:
+            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+            full_llm_pos_ids_list.append(m + max_s)
+
+        # combine all the segments and reshape to (3, input_ids_len)
+        llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
+        position_ids[..., :] = llm_positions.to(position_ids.device)
+        # TODO: avoid the extra dimension when updating the consumer of this function
+        return position_ids.unsqueeze(1)
 
     def forward(
         self,
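To make the vectorized arithmetic in the new get_position_ids easier to follow, here is a small standalone sketch (not part of the commit) that replays the same steps on made-up numbers: two images separated by text, with grid heights and widths already divided by spatial_merge_size. Variable names loosely mirror the diff; every token index and grid size below is an assumption chosen for illustration.

import torch

# Assumed toy layout: 14 text tokens, image A, 9 text tokens, image B, 5 trailing
# text tokens. Grids are (t, h, w) after merging, so image A has 1*2*3 = 6 tokens
# and image B has 1*2*2 = 4 tokens; total sequence length is 42.
input_ids_len = 42
image_grid_thw = torch.tensor([[1, 2, 3], [1, 2, 2]])
starts = torch.tensor([14, 31])  # <vision_start> indices
ends = torch.tensor([21, 36])    # <vision_end> indices

# text length of each segment counts the <vision_start> token too, hence the +1
prev_end = torch.cat([torch.zeros(1, dtype=ends.dtype), ends[:-1]])  # [0, 21]
text_lengths = starts - prev_end + 1                                 # [15, 11]

# cumulative offsets: each image advances the position counter by its merged width
vision_widths_max = torch.cat([torch.zeros(1, dtype=ends.dtype), image_grid_thw[:-1, 2]])
total_segment_lengths = (vision_widths_max + text_lengths).cumsum(dim=0)  # [15, 29]
text_diff = total_segment_lengths - text_lengths                          # [0, 18]

# (t, h, w) grid positions for each image, shifted past the text that precedes it
llm_pos_ids_list = []
for i, (t, h, w) in enumerate(image_grid_thw.tolist()):
    t_idx = torch.arange(t).repeat_interleave(h * w)
    h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
    w_idx = torch.arange(w).repeat(t * h)
    llm_pos_ids_list.append(torch.stack([t_idx, h_idx, w_idx]) + total_segment_lengths[i])

# plain ranges for the text segments, broadcast to the three mrope axes
text_ranges = [
    torch.arange(n).view(1, -1).expand(3, -1) + text_diff[i]
    for i, n in enumerate(text_lengths.tolist())
]

# interleave text and vision segments, then append the trailing text
full = [x for pair in zip(text_ranges, llm_pos_ids_list) for x in pair]
final_text_len = input_ids_len - int(ends[-1])  # 6: <vision_end> of B plus 5 text tokens
full.append(torch.arange(final_text_len).view(1, -1).expand(3, -1) + full[-1].max() + 1)

llm_positions = torch.cat(full, dim=1)
print(llm_positions.shape)      # torch.Size([3, 42]) -- one (t, h, w) id per token
print(llm_positions[:, 15:21])  # image A: t stays at 15, h spans 15-16, w spans 15-17

Every token, text or image, ends up with a triple of temporal/height/width position ids, which is the layout the "mrope" rope_type configured earlier in the model expects.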