diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 6391f9eb..05ed0202 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -230,14 +230,7 @@ struct QuantizationConfig {
 }
 
 #[derive(Debug, Deserialize)]
-struct VisionConfig {
-    depth: Option<usize>,
-    embed_dim: Option<usize>,
-    mlp_ratio: Option<usize>,
-    in_chans: Option<usize>,
-    patch_size: Option<usize>,
-    temporal_patch_size: Option<usize>,
-}
+struct VisionConfig {}
 
 #[derive(Debug, Deserialize)]
 struct Config {
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 7e296b42..fdc426bc 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -382,6 +382,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         config.rope_scaling.update({"rope_type": "mrope"})
         self.hidden_size = config.hidden_size
         self.vision_start_token_id = config.vision_start_token_id
+        self.vision_end_token_id = config.vision_end_token_id
         self.image_token_id = config.image_token_id
         self.video_token_id = config.video_token_id
         self.spatial_merge_size = config.vision_config.spatial_merge_size
@@ -411,98 +412,112 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_position_ids(
         self,
-        batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if batch_input_ids.dim() == 1:
-            batch_input_ids = batch_input_ids.unsqueeze(0)
+        input_ids: torch.Tensor,
+        image_grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO: avoid the early return and extra work in a more efficient way
+        if image_grid_thw is None:
+
+            if input_ids.dim() == 1:
+                input_ids = input_ids.unsqueeze(0)
+
+            position_ids = torch.ones(
+                3,
+                1,
+                input_ids.shape[0],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            position_ids = (
+                torch.arange(input_ids.shape[1], device=input_ids.device)
+                .view(1, 1, -1)
+                .repeat(3, input_ids.shape[0], 1)
+            )
+            return position_ids
+
+        # if image grid provided then we need to calculate the position ids
+
+        spatial_merge_size = self.spatial_merge_size
+        vision_start_token_id = self.vision_start_token_id
+        vision_end_token_id = self.vision_end_token_id
+
+        device = input_ids.device
+        dtype = input_ids.dtype
+        input_ids_len = input_ids.shape[0]
 
         position_ids = torch.ones(
             3,
-            batch_input_ids.shape[0],
-            batch_input_ids.shape[1],
-            dtype=batch_input_ids.dtype,
-            device=batch_input_ids.device,
+            input_ids_len,
+            dtype=dtype,
+            device=device,
         )
-        d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
-            llm_pos_ids_list = []
-
-            for i, input_ids in enumerate(batch_input_ids):
-                vision_start_indices = torch.argwhere(
-                    input_ids == self.vision_start_token_id
-                ).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
-                image_count = (vision_tokens == self.image_token_id).sum().item()
+        # capture vision segments
+        starts = torch.where(input_ids == vision_start_token_id)[0]
+        ends = torch.where(input_ids == vision_end_token_id)[0]
+        # ie. [[ 14, 2181], [2212, 4379]]
+        vision_segments = torch.stack((starts, ends), dim=1)
+        # capture text lengths as the space between vision segments
 
-                current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
-                        .nonzero()[0]
-                        .item()
-                    )
-                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
-                    height //= self.spatial_merge_size
-                    width //= self.spatial_merge_size
+        prev_end = torch.cat(  # shift to the left to get the previous end
+            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
+        )  # ie. [0, 2181]
 
-                    # calculate the length of the text and image tokens
-                    text_length = next_image_pos
-                    start_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    )
+        # text is the space between the end of one vision segment and the start of the next
+        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]
 
-                    # text position ids
-                    text_pos_ids = torch.arange(text_length, device=d)
-                    text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
-                    llm_pos_ids_list.append(text_pos_ids)
+        # calculate the max id from the image width for each segment
+        vision_widths_max = torch.cat(
+            [
+                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+                image_grid_thw[:-1, 2] // spatial_merge_size,
+            ]
+        )
+        total_segment_lengths = vision_widths_max + text_lengths
+        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
+        text_diff = total_segment_lengths - text_lengths
 
-                    # image position ids
-                    t_indices = torch.arange(time_steps, device=d).repeat_interleave(
-                        height * width
-                    )
-                    h_indices = (
-                        torch.arange(height, device=d)
-                        .repeat_interleave(width)
-                        .repeat(time_steps)
-                    )
-                    w_indices = torch.arange(width, device=d).repeat(
-                        height * time_steps
-                    )
-
-                    image_pos_ids = (
-                        torch.stack([t_indices, h_indices, w_indices])
-                        + text_length
-                        + start_idx
-                    )
-                    llm_pos_ids_list.append(image_pos_ids)
-
-                    current_pos += next_image_pos + time_steps * height * width
-                    image_index += 1
-
-                if current_pos < batch_input_ids.size(1):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    )
-                    text_len = batch_input_ids.size(1) - current_pos
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
-                    )
-
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[:, i, :] = llm_positions.to(position_ids.device)
-        else:
-            position_ids = (
-                torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
-                .view(1, 1, -1)
-                .repeat(3, batch_input_ids.shape[0], 1)
-            )
-        return position_ids
+        # create position ids for each vision segment based on the image grid
+        llm_pos_ids_list = []
+        for i, _ in enumerate(vision_segments):
+            t, h, w = (
+                image_grid_thw[i][0],
+                image_grid_thw[i][1] // spatial_merge_size,
+                image_grid_thw[i][2] // spatial_merge_size,
+            )
+            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+            w_indices = torch.arange(w, device=device).repeat(t * h)
+            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+            # offset by the position of the last vision segment
+            im = image_position_ids + total_segment_lengths[i]
+            llm_pos_ids_list.append(im)
+
+        # create position ids for each text segment
+        text_ranges = [
+            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+            + text_diff[i]
+            for i, seq_len in enumerate(text_lengths)
+        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
+
+        # combine by alternating text and vision segments (text, vision, text, vision, ...)
+        full_llm_pos_ids_list = [
+            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+        ]
+
+        # the final segment is the text between the last vision segment and the end of the input
+        max_s = full_llm_pos_ids_list[-1].max() + 1
+        final_text_len = input_ids_len - ends[-1]
+        if final_text_len > 0:
+            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+            full_llm_pos_ids_list.append(m + max_s)
+
+        # combine all the segments and reshape to (3, input_ids_len)
+        llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
+        position_ids[..., :] = llm_positions.to(position_ids.device)
+        # TODO: avoid the extra dimension when updating the consumer of this function
+        return position_ids.unsqueeze(1)
 
     def forward(
         self,
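Note (illustration, not part of the patch): the sketch below shows the idea behind the new vectorized get_position_ids on a single made-up prompt; the grid size, spatial_merge_size, and token counts are invented for demonstration. Text tokens advance all three rotary dimensions together, image tokens index (time, height, width) separately and are offset past the preceding text, and trailing text resumes one past the largest position used so far.

# Toy mrope example (not repository code): one image whose grid is 1 x 8 x 8 patches,
# spatial_merge_size = 2, so the merged grid contributes 1 x 4 x 4 = 16 image tokens.
import torch

spatial_merge_size = 2
t, h, w = 1, 8 // spatial_merge_size, 8 // spatial_merge_size

text_before = 6  # prompt text up to and including the vision start token (made up)
text_after = 4   # vision end token plus trailing text (made up)

# text positions advance all three rope dimensions together
prefix = torch.arange(text_before).view(1, -1).expand(3, -1)

# image positions index (time, height, width) independently, offset past the prefix
t_idx = torch.arange(t).repeat_interleave(h * w)
h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
w_idx = torch.arange(w).repeat(t * h)
image = torch.stack([t_idx, h_idx, w_idx]) + text_before

# trailing text resumes one past the largest position used so far
suffix = torch.arange(text_after).view(1, -1).expand(3, -1) + image.max() + 1

position_ids = torch.cat([prefix, image, suffix], dim=1)
print(position_ids.shape)  # torch.Size([3, 26])

For multiple images, the patch applies the same idea without a Python loop over tokens: it locates the vision start/end tokens, derives the text lengths between vision segments, and offsets each image block and text range by a cumulative segment length instead of copying indices between CPU and GPU inside a loop.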