From 44ed5efbcce0c795791a26a6de2181f0c6dbfbe8 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 18 Apr 2025 14:57:37 +0000 Subject: [PATCH] working --- .../models/flash_causal_lm.py | 4 + .../models/transformers_flash_vlm.py | 45 +++ .../models/vlm_causal_lm.py | 345 +++++++++++------- 3 files changed, 271 insertions(+), 123 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a28ef381..15f1d73d 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -207,6 +207,8 @@ class FlashCausalLMBatch(Batch): # Maximum number of blocks max_blocks: int + inputs_embeds: Optional[torch.Tensor] = None + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, @@ -1896,6 +1898,8 @@ class FlashCausalLM(Model): if prefill: batch.prepare_for_prefill() + self.get_input_embeddings(batch) + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index 280fa0bd..89ef2a9b 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -368,6 +368,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): image_grid_thw: Optional[torch.LongTensor] = None, pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, ): # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 @@ -377,9 +378,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, ) + inputs["input_ids"] = None + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( input_ids=inputs["input_ids"], + inputs_embeds=inputs_embeds.unsqueeze(0), position_ids=inputs["position_ids"], past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object @@ -568,3 +572,44 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): inputs["cache_position"] = position_ids inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) return inputs + + def get_vision_embeds(self, pixel_values, image_sizes=None): + image_features = self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.model.config.vision_config.vision_feature_layer, + vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.model.multi_modal_projector(vision_flat) + return projected_vision_flat + + def get_input_embeds(self, input_ids, vision_embeddings=None): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + + if vision_embeddings is not None: + original_inputs_embeds_shape = inputs_embeds.shape + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( + -1 + ) + final_mask = special_image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.view(-1, 
inputs_embeds.size(-1)) + + final_mask_1d = final_mask[..., 0].reshape(-1) + num_tokens_to_fill = final_mask_1d.sum() + + if num_tokens_to_fill != vision_embeddings.size(0): + raise ValueError( + f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + f"but multi_modal_projector returned {vision_embeddings.size(0)}" + ) + + expanded_mask = final_mask_1d.unsqueeze(-1).expand( + -1, inputs_embeds.size(-1) + ) + inputs_embeds = inputs_embeds.masked_scatter( + expanded_mask, vision_embeddings + ) + inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) + return inputs_embeds diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 0f3183df..178f736e 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import torch from PIL import Image from io import BytesIO @@ -115,7 +116,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" if processor.image_processor.do_image_splitting: image_str *= 5 - return image_str + return image_str, IDEFICS2_FAKE_TOKEN if config.model_type == "idefics3": # TODO: implement this in a more general way n_rows = image_input["rows"][0][image_id] @@ -132,7 +133,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_token=IDEFICS3_IMAGE_TOKEN, global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) - return image_str + return image_str, IDEFICS3_FAKE_IMAGE_TOKEN elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) @@ -141,32 +142,32 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str logger.info, f"Found {num_features} features in image of resolution {height}x{width}", ) - return "" * num_features + return "" * num_features, "" elif config.model_type == "paligemma": - return "" * config.text_config.num_image_tokens + return "" * config.text_config.num_image_tokens, "" elif config.model_type == "qwen2_vl": grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "qwen2_5_vl": grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "gemma3": # TODO: get correct number of features via reviewing the Gemma3 architecture # and calculating the number of image tokens num_pads = 256 padding = "" * num_pads - return f"\n\n{padding}\n\n" + return f"\n\n{padding}\n\n", "" elif config.model_type == "llama4": patch_size = config.vision_config.patch_size pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2))) - aspect_ratios = image_input["aspect_ratios"][image_id] - image_height, image_width = image_input["pixel_values"][image_id].shape[-2:] + aspect_ratios = image_input[image_id]["aspect_ratios"][0] + image_height, image_width = 
image_input[image_id]["pixel_values"][0].shape[-2:] num_patches_per_chunk = int( (image_height // patch_size) @@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str aspect_ratios, num_patches_per_chunk ) - return tokens_for_this_image + return tokens_for_this_image, "<|image_start|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -244,7 +245,42 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features +def scatter_image_embeds(embeds, is_embed): + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + +def gather_image_embeds(embeds, is_embed): + if is_embed is None: + return embeds + + gathered = embeds[is_embed] + + if len(gathered) == 0: + return None + return gathered + + +@dataclass +class ImagePositions: + offset: int + length: int + id: int + num_placeholder_tokens: int + is_embed: Optional[torch.Tensor] = None + + class VlmCausalLMBatch(FlashCausalLMBatch): + image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]] + image_positions: Optional[List[List[ImagePositions]]] + encoder_cache: Optional[List[Dict[int, torch.Tensor]]] pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @@ -276,8 +312,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch): # Process images first. We need all of them so that the processor # can make the image splits the same size. And we need the final # sizes to insert correct number of image tokens. - images = [] - for r in requests: + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + batch_image_inputs = [] + for i, r in enumerate(requests): + image_inputs = [] + batch_image_inputs.append(None) for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": @@ -292,46 +337,54 @@ class VlmCausalLMBatch(FlashCausalLMBatch): h = image.height * 2 image = image.resize((w, h)) - if config.model_type == "llava_next": - images.append(image) - elif config.model_type == "gemma3": - images.append(image) - elif config.model_type == "llama4": - images.append(image) + if config.model_type in {"llava_next", "gemma3", "llama4"}: + image = image else: - images.append([image]) + image = [image] + + pixel_values = processor.image_processor( + [image], return_tensors="pt", **kwargs + ) + + image_inputs.append(pixel_values) else: raise RuntimeError(f"Invalid chunk type {chunk_type}") - if images: - kwargs = {} - if ( - hasattr(processor, "image_processor_class") - and processor.image_processor_class == "Idefics3ImageProcessor" - ): - kwargs["return_row_col_info"] = True - - image_inputs = processor.image_processor( - images, return_tensors="pt", **kwargs - ) - else: - image_inputs = None + if len(image_inputs) > 0: + batch_image_inputs[i] = image_inputs + batch_image_positions = [] batch_tokenized_inputs = [] max_length = 0 - image_id = 0 - for r in requests: + for i, r in enumerate(requests): full_text = "" + image_tokens = [] + batch_image_positions.append(None) + image_id = 0 for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": full_text += chunk.text elif chunk_type == "image": - 
full_text += image_text_replacement( - processor, image_inputs, config, image_id + image_text, id_token = image_text_replacement( + processor, batch_image_inputs[i], config, image_id + ) + full_text += image_text + + if config.model_type == "gemma3": + image_text = image_text.replace("\n\n", "") + + image_tokens.append( + ( + image_id, + id_token, + tokenizer(image_text, add_special_tokens=False)[ + "input_ids" + ], + ) ) image_id += 1 - # from pdb import set_trace; set_trace() + full_text = image_text_replacement_fixup(config, full_text) input_ids = tokenizer( full_text, @@ -342,7 +395,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch): max_length = max(max_length, len(input_ids)) batch_tokenized_inputs.append(input_ids) - return batch_tokenized_inputs, image_inputs + prev = 0 + image_positions = [] + for image_id, id_token, tokens in image_tokens: + id_token = tokenizer.get_vocab()[id_token] + + length = len(tokens) + index = input_ids[prev:].index(id_token) + + index = index + prev + assert ( + input_ids[index : index + length] == tokens + ), "Image token not found in input_ids" + + if config.model_type in {"llava_next", "paligemma"}: + is_embed = None + num_placeholder_tokens = length + else: + is_embed = torch.tensor(tokens) == config.image_token_index + num_placeholder_tokens = is_embed.sum().item() + + pos = ImagePositions( + offset=index, + length=length, + id=image_id, + num_placeholder_tokens=num_placeholder_tokens, + is_embed=is_embed, + ) + + image_positions.append(pos) + prev = index + length + + if len(image_positions) > 0: + batch_image_positions[i] = image_positions + + return batch_tokenized_inputs, batch_image_inputs, batch_image_positions @classmethod def from_pb_processor( @@ -354,27 +441,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "VlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config + batch_tokenized_inputs, image_inputs, image_positions = ( + cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config) ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "image_grid_thw" in image_inputs: - batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) - else: - batch.image_grid_thw = None - else: + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = [{} for _ in range(len(pb.requests))] + if len(image_inputs): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None @@ -394,58 +468,77 @@ class VlmCausalLMBatch(FlashCausalLMBatch): r, cache_length, input_length, - prompt_length, request_prefilling, ) in enumerate( zip( self.requests, self.cache_lengths, self.input_lengths, - self.prompt_lengths, self.prefilling_mask, ) ): - if not request_prefilling: + if not request_prefilling or self.image_positions[i] is None: continue - for mm_inputs in batch_mm_inputs[i]: - for j, mm_input in enumerate(mm_input): - image_id = mm_input.id - pixel_values = self.all_pixel_values[i][j].pixel_values + for j, 
image_position in enumerate(self.image_positions[i]): + image_id = image_position.id + pixel_values = self.image_inputs[i][j] - start_pos = mm_input.offset - length = mm_input.length - num_placeholder_tokens = mm_input.num_placeholder_tokens + start_pos = image_position.offset + length = image_position.length - if start_pos >= cache_length + input_length: - # No encoder input required at this step - break - if start_pos + length <= cache_length: - # The encode input is already processed - continue - - self.has_image = True + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + self.has_image = True + + if image_id not in self.encoder_cache[i]: + self.scheduled_image_input.append((i, image_position)) + scheduled_image_pixel_values.append(pixel_values) - if image_id not in self.encoder_cache[i][image_id]: - self.scheduled_image_input.append((i, mm_input)) - scheduled_image_pixel_values.append(pixel_values) - if self.has_image and len(scheduled_image_pixel_values): - self.pixel_values = torch.cat([scheduled_image_pixel_values], dim=0).to(device) + self.pixel_values = torch.cat( + [d["pixel_values"] for d in scheduled_image_pixel_values], dim=0 + ).to(device) + if "pixel_attention_mask" in scheduled_image_pixel_values[0]: + self.pixel_attention_mask = torch.cat( + [d["pixel_attention_mask"] for d in scheduled_image_pixel_values], + dim=0, + ).to(device) + + if "image_sizes" in scheduled_image_pixel_values[0]: + self.image_sizes = torch.cat( + [d["image_sizes"] for d in scheduled_image_pixel_values], dim=0 + ).to(device) + + if "image_grid_thw" in scheduled_image_pixel_values[0]: + self.image_grid_thw = torch.cat( + [d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0 + ).to(device) + else: + self.pixel_values = None + self.pixel_attention_mask = None + self.image_sizes = None + self.image_grid_thw = None def update_encoder_cache(self, encoder_outputs): prev = 0 for i, input in self.scheduled_image_input: length = input.num_placeholder_tokens - output = encoder_outputs[prev:length] - batch.encoder_cache[i][image_id] = self.scatter_image_embed(output, input.is_embed) - - prev = length - + self.encoder_cache[i][input.id] = scatter_image_embeds( + encoder_outputs[prev : prev + length], input.is_embed + ) + + prev = prev + length + def get_mm_embeddings(self): device = self.input_ids.device - + mm_embeds = [] for i, ( r, cache_length, @@ -461,51 +554,43 @@ class VlmCausalLMBatch(FlashCausalLMBatch): self.prefilling_mask, ) ): - if not request_prefilling: + if not request_prefilling or self.image_positions[i] is None: continue - for mm_inputs in batch_mm_inputs[i]: - for j, mm_input in enumerate(mm_input): - image_id = mm_input.id - pixel_values = self.all_pixel_values[i][j].pixel_values + for j, image_position in enumerate(self.image_positions[i]): + image_id = image_position.id - start_pos = mm_input.offset - length = mm_input.length - num_placeholder_tokens = mm_input.num_placeholder_tokens + start_pos = image_position.offset + length = image_position.length - if start_pos >= cache_length + input_length: - # No encoder input required at this step - break - if start_pos + length <= cache_length: - # The encode input is already processed - continue - - self.has_image = True + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already 
processed + continue - if image_id not in self.encoder_cache[i][image_id]: - self.scheduled_image_input.append(mm_input) - scheduled_image_pixel_values.append(pixel_values) - + start_idx = max(cache_length - start_pos, 0) + end_idx = min(cache_length - start_pos + input_length, length) - start_idx = max(cache_length - start_pos, 0) - end_idx = min( - cache_length - start_pos + input_length, - length) - - encoder_output = self.encoder_cache[i][image_id] + assert ( + image_id in self.encoder_cache[i] + ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}" + encoder_output = self.encoder_cache[i][image_id] - is_embed = pos_info.is_embed - if is_embed is not None: - is_embed = is_embed[start_idx:end_idx] + is_embed = image_position.is_embed + if is_embed is not None: + is_embed = is_embed[start_idx:end_idx] - mm_embeds_item = gather_mm_embeds( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) + mm_embeds_item = gather_image_embeds( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds.append(mm_embeds_item) return torch.cat(mm_embeds, dim=0).to(device) + class VlmCausalLM(FlashCausalLM): def __init__( self, @@ -544,12 +629,24 @@ class VlmCausalLM(FlashCausalLM): def get_mm_embeddings(self, batch): if batch.pixel_values is not None: - encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values) - + encoder_outputs = self.get_vision_embeds(batch.pixel_values) batch.update_encoder_cache(encoder_outputs) - + batch.pixel_values = None return batch.get_mm_embeddings() + def get_input_embeddings(self, batch): + if batch.has_image: + vision_embeddings = self.get_mm_embeddings(batch) + batch.has_image = False + else: + vision_embeddings = None + + inputs_embeds = self.get_input_embeds( + batch.input_ids, vision_embeddings=vision_embeddings + ) + + batch.inputs_embeds = inputs_embeds + def forward( self, batch: VlmCausalLMBatch, @@ -600,6 +697,7 @@ class VlmCausalLM(FlashCausalLM): position_ids = new_position_ids else: input_ids = batch.input_ids + inputs_embeds = batch.inputs_embeds position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache @@ -659,6 +757,7 @@ class VlmCausalLM(FlashCausalLM): ) logits, speculative_logits = self.model.forward( input_ids=input_ids, + inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache,
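
Reviewer note (illustration only, not part of the patch): the sketch below shows how the new per-image encoder cache is meant to round-trip through scatter_image_embeds / gather_image_embeds, and how the start_idx / end_idx window in get_mm_embeddings slices it during chunked prefill. The helper bodies mirror the patch; the shapes, offsets, and the toy single-image layout are made up for illustration.

# Standalone sketch: per-image encoder cache round-trip with a partial prefill window.
# Shapes and offsets below are illustrative only.
import torch


def scatter_image_embeds(embeds, is_embed):
    # As in the patch: real embeddings go to positions where is_embed is True, NaN elsewhere,
    # so the cached tensor is indexed by position inside the image token replacement.
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan
    )
    placeholders[is_embed] = embeds
    return placeholders


def gather_image_embeds(embeds, is_embed):
    # As in the patch: drop the placeholder rows again, returning None if nothing remains.
    if is_embed is None:
        return embeds
    gathered = embeds[is_embed]
    return gathered if len(gathered) > 0 else None


hidden = 8
# A 6-token image replacement: <|vision_start|>, four <|image_pad|>, <|vision_end|>.
is_embed = torch.tensor([False, True, True, True, True, False])
vision_out = torch.randn(int(is_embed.sum()), hidden)  # one projector row per pad token

cached = scatter_image_embeds(vision_out, is_embed)  # (6, hidden), NaN at the marker rows
assert cached.shape == (6, hidden)

# Chunked prefill: the image occupies prompt positions [offset, offset + length) and the
# current step covers input_length tokens starting at cache_length.
offset, length, cache_length, input_length = 10, 6, 12, 3
start_idx = max(cache_length - offset, 0)                     # 2
end_idx = min(cache_length - offset + input_length, length)   # 5

chunk_embeds = gather_image_embeds(
    cached[start_idx:end_idx], is_embed[start_idx:end_idx]
)
print(chunk_embeds.shape)  # torch.Size([3, 8]): three placeholder tokens fall in this step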
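
A second sketch, also illustration only: the masked_scatter merge that TransformersLlama4VlmCausalLM.get_input_embeds uses to splice the projected vision embeddings into the text embeddings. The embedding table, hidden size, and image_token_index here are toy values; in the patch they come from the model and its config.

# Standalone sketch: merging vision embeddings into inputs_embeds via masked_scatter.
import torch

vocab_size, hidden, image_token_index = 32, 4, 7
embed_tokens = torch.nn.Embedding(vocab_size, hidden)

# Prompt with two image placeholder tokens at positions 1 and 2.
input_ids = torch.tensor([5, image_token_index, image_token_index, 9])
vision_embeddings = torch.ones(2, hidden)  # one projected row per placeholder token

inputs_embeds = embed_tokens(input_ids)
original_shape = inputs_embeds.shape

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))

final_mask_1d = special_image_mask[..., 0].reshape(-1)
# The patch raises a ValueError on this mismatch; an assert keeps the sketch short.
assert final_mask_1d.sum() == vision_embeddings.size(0)

expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeddings)
inputs_embeds = inputs_embeds.view(original_shape)

# Rows 1 and 2 now hold the vision embeddings; rows 0 and 3 keep the text embeddings.
print(torch.allclose(inputs_embeds[1:3], vision_embeddings))  # True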