diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index febf28da..13f12ef1 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -144,7 +144,7 @@ class TensorParallelColumnLinear(SuperLayer):
             num_key_value_heads=num_key_value_heads,
         )
         if bias:
-            bias = weights.get_tensor(f"{prefix}.bias")
+            raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
             bias = None
         linear = get_linear(weight, bias)
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 7ae43256..5fe39bc9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -130,28 +130,23 @@ class Qwen2Attention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # TODO: correctly handle the multimodal case
-        if False:
-            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        else:
-            # multimodal rotary
-            unsqueeze_dim = 1
-            mrope_section = self.mrope_section * 2
-            cos = torch.cat(
-                [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
-                dim=-1,
-            ).unsqueeze(unsqueeze_dim)
-            sin = torch.cat(
-                [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
-                dim=-1,
-            ).unsqueeze(unsqueeze_dim)
+        _query = query.clone()
+        _cos = cos.clone()
+        _sin = sin.clone()
 
-            _query = query.transpose(0, 1).unsqueeze(0)
-            _key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0)
-            q_embed = (_query * cos) + (rotate_half(_query) * sin)
-            k_embed = (_key * cos) + (rotate_half(_key) * sin)
-            query = q_embed.squeeze(0).transpose(0, 1)
-            kv[:, 0] = k_embed.squeeze(0).transpose(0, 1)
+        self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin)
+
+        _cos = torch.cat((_cos, _cos), dim=-1)
+        _sin = torch.cat((_sin, _sin), dim=-1)
+        q_emb = (_query * _cos).reshape(2, 1, -1) + (
+            rotate_half(_query) * _sin
+        ).reshape(2, 1, -1)
+        k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + (
+            rotate_half(torch.select(kv, dim=1, index=0)) * _sin
+        ).reshape(2, 1, -1)
+
+        query = q_emb.reshape(-1, self.num_heads, self.head_size)
+        kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size)
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
@@ -299,9 +294,6 @@ class Qwen2Model(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.embed_tokens", weights=weights
-        )
         self.layers = nn.ModuleList(
             [
                 Qwen2Layer(
@@ -325,7 +317,7 @@
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -335,25 +327,12 @@ class Qwen2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        hidden_states = inputs_embeds
 
-        # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds.squeeze(0)
-        else:
-            hidden_states = self.embed_tokens(input_ids)
-
-        # Get rotary cos and sin for this forward
-        # Avoid to index in each layer
-        # TODO: fix how N-D position_ids are handled
-
-        if position_ids.dim() == 2:
-            position_ids = position_ids.unsqueeze(0)
-
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
-            position_ids, true_max_s, hidden_states.dtype
+        # TODO: ensure we are getting the correct positional embeddings
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids[0, 0, :], true_max_s, hidden_states.dtype
         )
 
         residual = None
@@ -393,6 +372,11 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+
         self.max_past = config.sliding_window
         self.max_past_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
@@ -423,8 +407,10 @@
             # kernel requires the true values
             seqlen = seqlen.clamp(max=self.max_past_tensor)
 
+        inputs_embeds = self.embed_tokens(input_ids)
+
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index ac66695a..2eb7b978 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -35,6 +35,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
+    TensorParallelEmbedding,
     FastLinear,
 )
 from text_generation_server.layers.attention import (
@@ -78,11 +79,11 @@ class Qwen2VLSdpaAttention(nn.Module):
             config,
             prefix=f"{prefix}.qkv",
             weights=weights,
-            bias=True,
+            bias=False,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_heads,
         )
-
+        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
         self.proj = TensorParallelColumnLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
@@ -285,7 +286,6 @@ class Qwen2VisionModel(nn.Module):
         self,
         pixel_values: torch.Tensor,
         aspect_ratio_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         # reshape the input tensor for processing
@@ -361,7 +361,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.image_token_id = config.image_token_id
         self.video_token_id = config.video_token_id
         self.spatial_merge_size = config.vision_config.spatial_merge_size
-
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"model.embed_tokens", weights=weights
+        )
         self.visual = Qwen2VisionModel(
             prefix="visual", config=config.vision_config, weights=weights
         )
@@ -371,6 +373,93 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         self.device = weights.device
 
+    def get_position_ids(
+        self,
+        batch_input_ids: torch.Tensor,
+        image_grid_thw: Optional[torch.LongTensor],
+        # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        position_ids = torch.ones(
+            3,
+            batch_input_ids.shape[0],
+            batch_input_ids.shape[1],
+            dtype=batch_input_ids.dtype,
+            device=batch_input_ids.device,
+        )
+        d = batch_input_ids.device
+        if image_grid_thw is not None:
+            image_index = 0
+            llm_pos_ids_list = []
+
+            for i, input_ids in enumerate(batch_input_ids):
+                vision_start_indices = torch.argwhere(
+                    input_ids == self.vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                # only copy the sum of the image tokens GPU<->CPU
+                image_count = (vision_tokens == self.image_token_id).sum().item()
+
+                current_pos = 0
+                for _ in range(image_count):
+                    # copy the value position of the next image token from GPU<->CPU
+                    next_image_pos = (
+                        (input_ids[current_pos:] == self.image_token_id)
+                        .nonzero()[0]
+                        .item()
+                    )
+                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
+                    time_steps, height, width = image_grid_thw[image_index]
+                    height //= self.spatial_merge_size
+                    width //= self.spatial_merge_size
+
+                    # calculate the length of the text and image tokens
+                    text_length = next_image_pos - current_pos
+                    start_idx = (
+                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    )
+
+                    # text position ids
+                    text_pos_ids = torch.arange(text_length, device=d)
+                    text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
+                    llm_pos_ids_list.append(text_pos_ids)
+
+                    # image position ids
+                    t_indices = torch.arange(time_steps, device=d).repeat_interleave(
+                        height * width
+                    )
+                    h_indices = (
+                        torch.arange(height, device=d)
+                        .repeat_interleave(width)
+                        .repeat(time_steps)
+                    )
+                    w_indices = torch.arange(width, device=d).repeat(
+                        height * time_steps
+                    )
+
+                    image_pos_ids = (
+                        torch.stack([t_indices, h_indices, w_indices])
+                        + text_length
+                        + start_idx
+                    )
+                    llm_pos_ids_list.append(image_pos_ids)
+
+                    current_pos = next_image_pos + time_steps * height * width
+                    image_index += 1
+
+            if current_pos < batch_input_ids.size(1):
+                st_idx = (
+                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                )
+                text_len = batch_input_ids.size(1) - current_pos
+                llm_pos_ids_list.append(
+                    torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
+                )
+
+            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+            position_ids[:, i, :] = llm_positions.to(position_ids.device)
+
+        return position_ids
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -392,147 +481,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
     ):
-        # make an attention_mask that is (batch_size, sequence_length)
-        attention_mask = torch.ones_like(
-            input_ids, dtype=torch.bool, device=input_ids.device
-        )
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
             if pixel_values is not None:
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                image_mask = (
-                    (input_ids == self.image_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_embeds = image_embeds.to(
-                    inputs_embeds.device, inputs_embeds.dtype
-                )
-                # input embeddings are masked with image embeddings
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        # handle the position_ids taking the multimodal inputs into account
-        mrope_position_deltas = []
-        if image_grid_thw is not None or video_grid_thw is not None:
-            total_input_ids = input_ids
-            position_ids = torch.ones(
-                3,
-                input_ids.shape[0],
-                input_ids.shape[1],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-            image_index, video_index = 0, 0
-
-            for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
-                image_nums, video_nums = 0, 0
-                vision_start_indices = torch.argwhere(
-                    input_ids == self.vision_start_token_id
-                ).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                # determine the number of images and videos in the input
-                image_nums = (vision_tokens == self.image_token_id).sum()
-                video_nums = (vision_tokens == self.video_token_id).sum()
-                input_tokens = input_ids.tolist()
-                llm_pos_ids_list: list = []
-                st = 0
-                remain_images, remain_videos = image_nums, video_nums
-                # process each input based on it's token type and grid size
-                for _ in range(image_nums + video_nums):
-                    if self.image_token_id in input_tokens and remain_images > 0:
-                        ed_image = input_tokens.index(self.image_token_id, st)
-                    else:
-                        ed_image = len(input_tokens) + 1
-                    if self.video_token_id in input_tokens and remain_videos > 0:
-                        ed_video = input_tokens.index(self.video_token_id, st)
-                    else:
-                        ed_video = len(input_tokens) + 1
-                    if ed_image < ed_video:
-                        t, h, w = (
-                            image_grid_thw[image_index][0],
-                            image_grid_thw[image_index][1],
-                            image_grid_thw[image_index][2],
-                        )
-                        image_index += 1
-                        remain_images -= 1
-                        ed = ed_image
-                    else:
-                        t, h, w = (
-                            video_grid_thw[video_index][0],
-                            video_grid_thw[video_index][1],
-                            video_grid_thw[video_index][2],
-                        )
-                        video_index += 1
-                        remain_videos -= 1
-                        ed = ed_video
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t.item(),
-                        h.item() // self.spatial_merge_size,
-                        w.item() // self.spatial_merge_size,
-                    )
-                    text_len = ed - st
-
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-                    )
-
-                    t_index = (
-                        torch.arange(llm_grid_t)
-                        .view(-1, 1)
-                        .expand(-1, llm_grid_h * llm_grid_w)
-                        .flatten()
-                    )
-                    h_index = (
-                        torch.arange(llm_grid_h)
-                        .view(1, -1, 1)
-                        .expand(llm_grid_t, -1, llm_grid_w)
-                        .flatten()
-                    )
-                    w_index = (
-                        torch.arange(llm_grid_w)
-                        .view(1, 1, -1)
-                        .expand(llm_grid_t, llm_grid_h, -1)
-                        .flatten()
-                    )
-                    llm_pos_ids_list.append(
-                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-                    )
-                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-                if st < len(input_tokens):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    text_len = len(input_tokens) - st
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-                    )
-
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
-                    position_ids.device
-                )
-                mrope_position_deltas.append(
-                    llm_positions.max() + 1 - len(total_input_ids[i])
-                )
-            mrope_position_deltas = torch.tensor(
-                mrope_position_deltas, device=input_ids.device
-            ).unsqueeze(1)
+                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+        position_ids = self.get_position_ids(input_ids, image_grid_thw)
 
         outputs = self.text_model(
-            input_ids=input_ids,
+            inputs_embeds=inputs_embeds.squeeze(0),
            position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
@@ -542,8 +501,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
         )
         logits = self.lm_head(outputs)
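
Note on the new position id layout (illustration only, not part of the patch): the sketch below is a minimal standalone example of how the added get_position_ids interleaves text and image positions along the three mrope axes (temporal, height, width). The token ids, the single 4x4 image grid, and the prompt layout are made up for the example; it assumes a spatial merge size of 2, mirroring config.vision_config.spatial_merge_size.

# Illustrative sketch only -- hypothetical token ids and image grid.
import torch

spatial_merge_size = 2
image_token_id = 151655            # hypothetical id for the <image> placeholder token
grid_t, grid_h, grid_w = 1, 4, 4   # one image_grid_thw entry: 1 frame, 4x4 patches

height = grid_h // spatial_merge_size       # 2 merged rows
width = grid_w // spatial_merge_size        # 2 merged columns
num_image_tokens = grid_t * height * width  # 4 placeholder tokens in the prompt

# toy prompt: 3 text tokens, the image placeholders, then 2 trailing text tokens
input_ids = torch.tensor([10, 11, 12] + [image_token_id] * num_image_tokens + [13, 14])

# text before the image: all three axes share positions 0, 1, 2
text_pos = torch.arange(3).view(1, -1).expand(3, -1)

# image block: temporal axis constant, height varies slowest, width fastest,
# every axis offset by the preceding text length (3)
t_idx = torch.arange(grid_t).repeat_interleave(height * width)
h_idx = torch.arange(height).repeat_interleave(width).repeat(grid_t)
w_idx = torch.arange(width).repeat(height * grid_t)
image_pos = torch.stack([t_idx, h_idx, w_idx]) + 3

# trailing text resumes from max(previous positions) + 1
tail_pos = torch.arange(2).view(1, -1).expand(3, -1) + image_pos.max() + 1

position_ids = torch.cat([text_pos, image_pos, tail_pos], dim=1)  # shape (3, 9)
print(position_ids)
# tensor([[0, 1, 2, 3, 3, 3, 3, 5, 6],
#         [0, 1, 2, 3, 3, 4, 4, 5, 6],
#         [0, 1, 2, 3, 4, 3, 4, 5, 6]])

Text tokens advance all three axes together, the image block keeps the temporal axis constant while enumerating the merged height/width grid, and trailing text continues from the maximum of the previous positions plus one; this is the same bookkeeping the patched method performs per batch element.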