diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py
new file mode 100644
index 00000000..8bcaeda2
--- /dev/null
+++ b/integration-tests/models/test_idefics3.py
@@ -0,0 +1,104 @@
+import pytest
+import base64
+
+
+# TODO: fix the server parser to count inline image tokens correctly
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture(scope="module")
+def flash_idefics3_next_handle(launcher):
+    with launcher(
+        "HuggingFaceM4/Idefics3-8B-Llama3",
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_idefics3_next(flash_idefics3_next_handle):
+    await flash_idefics3_next_handle.health(300)
+    return flash_idefics3_next_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_simple(flash_idefics3_next, response_snapshot):
+    chicken = get_chicken()
+    response = await flash_idefics3_next.generate(
+        f"User:![]({chicken})Write me a short story \nAssistant:",
+        max_new_tokens=10,
+    )
+    assert (
+        response.generated_text == " A chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_two_images(flash_idefics3_next, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_idefics3_next.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text
+        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 19
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_all_params(flash_idefics3_next, response_snapshot):
+    response = await flash_idefics3_next.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_load(
+    flash_idefics3_next, generate_load, response_snapshot
+):
+    chicken = get_chicken()
+    responses = await generate_load(
+        flash_idefics3_next,
+        f"User:![]({chicken})Write me a short story \nAssistant:",
+        max_new_tokens=10,
+        n=4,
+    )
+    generated_texts = [r.generated_text for r in responses]
+    assert generated_texts[0] == " A chicken is sitting on a pile of money."
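+    # With greedy decoding (no sampling parameters are set), all four
+    # concurrent requests for the same prompt should produce identical text.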
+    assert len(generated_texts) == 4
+    assert all([r.generated_text == generated_texts[0] for r in responses])
+
+    assert responses == response_snapshot
diff --git a/router/src/config.rs b/router/src/config.rs
index 5d07a293..8510d356 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -110,6 +110,13 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics3 {
+    pub(crate) vision_encoder_max_image_size: usize,
+    pub(crate) image_seq_len: usize,
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
@@ -178,6 +185,7 @@ pub enum Config {
     Idefics,
     Mllama,
     Idefics2(Idefics2),
+    Idefics3(Idefics3),
     Ssm,
     GptBigcode,
     Granite,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 84e9bc48..21c45241 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -170,6 +170,7 @@ impl TokenizerConfigToken {
 #[serde(tag = "processor_class")]
 pub enum HubPreprocessorConfig {
     Idefics2Processor(Idefics2Preprocessor),
+    Idefics3Processor(Idefics2Preprocessor),
 }
 
 impl HubPreprocessorConfig {
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcc79608..c8b68960 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -151,6 +151,7 @@ try:
     )
     from text_generation_server.models.custom_modeling.idefics2 import (
         Idefics2ForConditionalGeneration,
+        Idefics3ForConditionalGeneration,
     )
     from text_generation_server.models.custom_modeling.qwen2_vl import (
         Qwen2VLForConditionalGeneration,
@@ -188,6 +189,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
         "multimodal": True,
     }
+    IDEFICS3 = {
+        "type": "idefics3",
+        "name": "Idefics 3",
+        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
+        "multimodal": True,
+    }
     LLAVA_NEXT = {
         "type": "llava_next",
         "name": "Llava Next (1.6)",
@@ -1253,6 +1260,23 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == IDEFICS3:
+        if FLASH_ATTENTION:
+            return VlmCausalLM(
+                model_id=model_id,
+                model_class=Idefics3ForConditionalGeneration,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                # XXX: Extremely important to cap resolution in order to limit
+                # VRAM usage.
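+                # The per-image token count grows roughly quadratically with
+                # the longest edge, so raising these sizes quickly inflates
+                # both prompt length and memory; 448/378 is a conservative cap.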
+ processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: if FLASH_ATTENTION: return VlmCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2c007d15..3e565109 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"), config=config, weights=weights, ) @@ -564,7 +562,7 @@ class FlashLlamaModel(torch.nn.Module): prefix=( f"model.layers.{last_layer_id}" if not prefix - else f"{prefix}.model.layers.{last_layer_id}" + else f"{prefix}.layers.{last_layer_id}" ), config=config, weights=weights, @@ -572,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module): ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix="model.norm" if not prefix else f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -635,9 +633,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" - if not prefix - else f"{prefix}.model.embed_tokens" + "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens" ), weights=weights, ) @@ -655,7 +651,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", + prefix=suffix, weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 923123d6..55cfb9e6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -679,6 +679,281 @@ class Idefics2Connector(nn.Module): return image_hidden_states +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = TensorParallelRowLinear.load( + prefix=f"{prefix}.modality_projection.proj", + config=config, + weights=weights, + bias=False, + ) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states, attention_mask): + print(image_hidden_states.device, self.modality_projection.linear.weight.device) + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class PerceiverConfig: + def __init__(self, config_dict): + self._name_or_path = 
config_dict.get("_name_or_path", "") + self.add_cross_attention = config_dict.get("add_cross_attention", False) + self.architectures = config_dict.get("architectures", None) + self.attention_dropout = config_dict.get("attention_dropout", 0.0) + self.bad_words_ids = config_dict.get("bad_words_ids", None) + self.begin_suppress_tokens = config_dict.get("begin_suppress_tokens", None) + self.bos_token_id = config_dict.get("bos_token_id", None) + self.chunk_size_feed_forward = config_dict.get("chunk_size_feed_forward", 0) + self.cross_attention_hidden_size = config_dict.get( + "cross_attention_hidden_size", None + ) + self.decoder_start_token_id = config_dict.get("decoder_start_token_id", None) + self.diversity_penalty = config_dict.get("diversity_penalty", 0.0) + self.do_sample = config_dict.get("do_sample", False) + self.early_stopping = config_dict.get("early_stopping", False) + self.encoder_no_repeat_ngram_size = config_dict.get( + "encoder_no_repeat_ngram_size", 0 + ) + self.eos_token_id = config_dict.get("eos_token_id", None) + self.exponential_decay_length_penalty = config_dict.get( + "exponential_decay_length_penalty", None + ) + self.finetuning_task = config_dict.get("finetuning_task", None) + self.forced_bos_token_id = config_dict.get("forced_bos_token_id", None) + self.forced_eos_token_id = config_dict.get("forced_eos_token_id", None) + self.hidden_act = config_dict.get("hidden_act", "silu") + self.id2label = config_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}) + self.is_decoder = config_dict.get("is_decoder", False) + self.is_encoder_decoder = config_dict.get("is_encoder_decoder", False) + self.label2id = config_dict.get("label2id", {"LABEL_0": 0, "LABEL_1": 1}) + self.length_penalty = config_dict.get("length_penalty", 1.0) + self.max_length = config_dict.get("max_length", 20) + self.min_length = config_dict.get("min_length", 0) + self.model_type = config_dict.get("model_type", "idefics3") + self.no_repeat_ngram_size = config_dict.get("no_repeat_ngram_size", 0) + self.num_beam_groups = config_dict.get("num_beam_groups", 1) + self.num_beams = config_dict.get("num_beams", 1) + self.num_key_value_heads = config_dict.get("num_key_value_heads", 1) + self.num_return_sequences = config_dict.get("num_return_sequences", 1) + self.output_attentions = config_dict.get("output_attentions", False) + self.output_hidden_states = config_dict.get("output_hidden_states", False) + self.output_scores = config_dict.get("output_scores", False) + self.pad_token_id = config_dict.get("pad_token_id", 128002) + self.prefix = config_dict.get("prefix", None) + self.problem_type = config_dict.get("problem_type", None) + self.pruned_heads = config_dict.get("pruned_heads", {}) + self.qk_layer_norms_perceiver = config_dict.get( + "qk_layer_norms_perceiver", False + ) + self.remove_invalid_values = config_dict.get("remove_invalid_values", False) + self.repetition_penalty = config_dict.get("repetition_penalty", 1.0) + self.resampler_depth = config_dict.get("resampler_depth", 6) + self.resampler_head_dim = config_dict.get("resampler_head_dim", 96) + self.resampler_n_heads = config_dict.get("resampler_n_heads", 16) + self.resampler_n_latents = config_dict.get("resampler_n_latents", 64) + self.return_dict = config_dict.get("return_dict", True) + self.return_dict_in_generate = config_dict.get("return_dict_in_generate", False) + self.sep_token_id = config_dict.get("sep_token_id", None) + self.suppress_tokens = config_dict.get("suppress_tokens", None) + self.task_specific_params = 
config_dict.get("task_specific_params", None) + self.temperature = config_dict.get("temperature", 1.0) + self.tf_legacy_loss = config_dict.get("tf_legacy_loss", False) + self.tie_encoder_decoder = config_dict.get("tie_encoder_decoder", False) + self.tie_word_embeddings = config_dict.get("tie_word_embeddings", True) + self.tokenizer_class = config_dict.get("tokenizer_class", None) + self.top_k = config_dict.get("top_k", 50) + self.top_p = config_dict.get("top_p", 1.0) + self.torch_dtype = config_dict.get("torch_dtype", None) + self.torchscript = config_dict.get("torchscript", False) + self.transformers_version = config_dict.get("transformers_version", "4.43.2") + self.typical_p = config_dict.get("typical_p", 1.0) + self.use_bfloat16 = config_dict.get("use_bfloat16", False) + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + config.text_config.perceiver_config = PerceiverConfig( + config_dict=config.text_config.perceiver_config + ) + self.image_seq_len = config.text_config.perceiver_config.resampler_n_latents + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
+                nb_values_per_image = pixel_values.shape[1:].numel()
+                real_images_inds = (pixel_values == 0.0).sum(
+                    dim=(-1, -2, -3)
+                ) != nb_values_per_image
+                pixel_values = pixel_values[real_images_inds].contiguous()
+
+                # Handle the vision attention mask
+                if pixel_attention_mask is None:
+                    pixel_attention_mask = torch.ones(
+                        size=(
+                            pixel_values.size(0),
+                            pixel_values.size(2),
+                            pixel_values.size(3),
+                        ),
+                        dtype=torch.bool,
+                        device=pixel_values.device,
+                    )
+                else:
+                    # Remove padding images from the mask
+                    pixel_attention_mask = all_pixel_mask[i : i + 1]
+                    pixel_attention_mask = pixel_attention_mask.view(
+                        1 * num_images, *pixel_attention_mask.shape[2:]
+                    )
+                    pixel_attention_mask = pixel_attention_mask[
+                        real_images_inds
+                    ].contiguous()
+
+                patch_size = self.config.vision_config.patch_size
+                patches_subgrid = pixel_attention_mask.unfold(
+                    dimension=1, size=patch_size, step=patch_size
+                )
+                patches_subgrid = patches_subgrid.unfold(
+                    dimension=2, size=patch_size, step=patch_size
+                )
+                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values,
+                    patch_attention_mask=patch_attention_mask,
+                )
+
+                # Modality projection (pixel shuffle + linear projection)
+                image_hidden_states = self.connector(
+                    image_hidden_states,
+                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+                )
+
+                all_states.append(image_hidden_states)
+            image_hidden_states = torch.stack(all_states, dim=0)
+            # When we generate, we don't want to replace image_token_id tokens
+            # that the model itself generated with images that simply don't exist.
+            # TODO: finish implementing the image token replacement
+
+            # inputs_embeds = self.inputs_merger(
+            #     input_ids=input_ids,
+            #     inputs_embeds=inputs_embeds,
+            #     image_hidden_states=image_hidden_states,
+            # )
+
+            # num_images, _, vision_hidden_size = image_hidden_states.shape
+            # special_image_token_mask = input_ids == self.image_token_id
+            # new_inputs_embeds = inputs_embeds.clone()
+            # reshaped_image_hidden_states = image_hidden_states.view(
+            #     -1, vision_hidden_size
+            # ).to(inputs_embeds.dtype)  # cast to the dtype of the inputs_embeds to support quantized models
+            # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+            # inputs_embeds = new_inputs_embeds
+
+        hidden_states = self.text_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            true_max_s=max_s,
+            prefill_cache_indices=None,
+            adapter_data=adapter_data,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
+        return logits, speculative_logits
+
+
 class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py
index 82e409a6..04edd0a4 100644
--- a/server/text_generation_server/models/custom_modeling/vlm.py
+++ b/server/text_generation_server/models/custom_modeling/vlm.py
@@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashLlamaForCausalLM,
         )
 
-        return FlashLlamaForCausalLM(prefix, config, weights)
+        return FlashLlamaForCausalLM(f"{prefix}.text_model", config, weights)
     elif config.model_type == "mistral":
         from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
             FlashMistralForCausalLM,
         )
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 81b4369b..61562f8a 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -54,6 +54,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         if processor.image_processor.do_image_splitting:
             image_str *= 5
         return image_str
+    elif config.model_type == "idefics3":
+        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
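Reviewer note: the connector's pixel shuffle can be sanity-checked in isolation. A minimal sketch (hypothetical sizes, not part of the diff; assumes the vision tower emits a square patch grid, e.g. 32x32 patches at a 1152-wide hidden dim):

import torch

def pixel_shuffle(x, scale_factor=2):
    # Same reshaping as Idefics3Connector.pixel_shuffle above.
    bsz, seq, embed_dim = x.size()
    height = width = int(seq**0.5)
    x = x.view(bsz, height, width, embed_dim)
    x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(
        bsz,
        int(width / scale_factor),
        int(height / scale_factor),
        embed_dim * (scale_factor**2),
    )
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
    return x

x = torch.randn(1, 1024, 1152)   # 32x32 patch grid from the vision encoder
y = pixel_shuffle(x, scale_factor=2)
assert y.shape == (1, 256, 4608)  # 4x fewer image tokens, 4x wider features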