diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 24ecd2ad..545cddd0 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -114,8 +114,8 @@ impl Client { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let mut inputs = String::new(); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); requests.push(Request { id: 0, diff --git a/router/src/config.rs b/router/src/config.rs index 4ee4704f..8c9fa33f 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -66,7 +66,8 @@ impl LlavaNext { let (num_patch_height, num_patch_width) = get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); // Ceil - let height_of_patch = (height * npatches + width - 1) / width; + // TODO Very odd artifact when the rounding is super close + let height_of_patch = (height * npatches + width - 10) / width; let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width; // They are only added after width let newline_features = height_of_patch * num_patch_width; @@ -166,5 +167,7 @@ mod test { assert_eq!(slots, 2732); let slots = config.get_number_of_features(1024, 899); assert_eq!(slots, 3320); + let slots = config.get_number_of_features(1067, 1600); + assert_eq!(slots, 2144); } } diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 95a4d476..0029e3b2 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -739,6 +739,8 @@ class Idefics2ForConditionalGeneration(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 08ac9fcf..14bf19e1 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -170,6 +170,8 @@ class LlavaNextForConditionalGeneration(nn.Module): prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, ): inputs_embeds = self.language_model.embed_tokens(input_ids) diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py index f759300d..e831af89 100644 --- a/server/text_generation_server/models/idefics2.py +++ b/server/text_generation_server/models/idefics2.py @@ -26,6 +26,8 @@ class Idefics2(VlmCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. size={"longest_edge": 448, "shortest_edge": 378}, ) super().__init__( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ab87c5c7..1e60ab1f 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -64,6 +64,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size +def image_text_replacement(image_input, config) -> str: + if config.model_type == "idefics2": + # TODO technically depends on image splitting which is not implemented. + num_features = 320 + return ( + "" + + "" * num_features + + "" + ) + elif config.model_type == "llava_next": + height, width = image_input["image_sizes"][0] + num_features = get_number_of_features(height, width, config) + from loguru import logger + + logger.info(f"Found {num_features} in image of resolution {height}x{width}") + return "" * num_features + else: + raise RuntimeError(f"Unknown config {config.model_type} for multimodal") + + def get_number_of_features(height: int, width: int, config) -> int: # From config # Hardcoded for CLIP for now @@ -82,7 +102,7 @@ def get_number_of_features(height: int, width: int, config) -> int: image_size, ) - height_of_patch = math.ceil(height / width * npatches) + height_of_patch = (height * npatches + width - 10) // width unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width # They are only added after width @@ -99,12 +119,9 @@ def load_data_uri(image_uri: str) -> Image.Image: return image -# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}" -# assert get_number_of_features(640, 640) == 2928 - - class VlmCausalLMBatch(FlashMistralBatch): pixel_values: Optional[List[torch.Tensor]] + pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @classmethod @@ -112,6 +129,7 @@ class VlmCausalLMBatch(FlashMistralBatch): def concatenate(cls, batches): batch = super(VlmCausalLMBatch, cls).concatenate(batches) batch.pixel_values = None + batch.pixel_attention_mask = None batch.image_sizes = None return batch @@ -119,6 +137,7 @@ class VlmCausalLMBatch(FlashMistralBatch): def filter(self, request_ids: List[int]): batch = super().filter(request_ids) batch.pixel_values = None + batch.pixel_attention_mask = None batch.image_sizes = None return batch @@ -147,11 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch): "Cannot process input image not starting with data:" ) image_input = processor.image_processor(image, return_tensors="pt") - # import ipdb;ipdb.set_trace() - # height, width = image_input["image_sizes"][0] - # num_features = get_number_of_features(height, width, config) - num_features = 320 - full_text += "" * num_features + full_text += image_text_replacement(image_input, config) image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk['type']}") @@ -163,15 +178,21 @@ class VlmCausalLMBatch(FlashMistralBatch): batch_inputs, truncation=True, max_length=max_truncation )["input_ids"] if image_inputs: - image_inputs = { + image_input = image_inputs[0] + new_image_inputs = { "pixel_values": torch.cat( [img["pixel_values"] for img in image_inputs], dim=0 ), - "pixel_attention_mask": torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ), - # "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]), } + if "pixel_attention_mask" in image_input: + new_image_inputs["pixel_attention_mask"] = torch.cat( + [img["pixel_attention_mask"] for img in image_inputs], dim=0 + ) + if "image_sizes" in image_input: + new_image_inputs["image_sizes"] = torch.cat( + [img["image_sizes"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs else: image_inputs = None return batch_tokenized_inputs, image_inputs @@ -192,14 +213,20 @@ class VlmCausalLMBatch(FlashMistralBatch): batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) if image_inputs is not None: batch.pixel_values = image_inputs["pixel_values"].to(device=device) - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - # batch.image_sizes = image_inputs["image_sizes"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None else: batch.pixel_values = None batch.pixel_attention_mask = None - # batch.image_sizes = None + batch.image_sizes = None return batch @@ -291,7 +318,7 @@ class VlmCausalLM(BaseFlashMistral): lm_head_indices=lm_head_indices, pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, - # image_sizes=batch.image_sizes, + image_sizes=batch.image_sizes, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -299,8 +326,8 @@ class VlmCausalLM(BaseFlashMistral): batch.pixel_values = None if batch.pixel_attention_mask is not None: batch.pixel_attention_mask = None - # if batch.image_sizes is not None: - # batch.image_sizes = None + if batch.image_sizes is not None: + batch.image_sizes = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph