diff --git a/server/text_generation_server/models/custom_modeling/idefics_config.py b/server/text_generation_server/models/custom_modeling/idefics_config.py
new file mode 100644
index 00000000..34925087
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_config.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Idefics model configuration"""
+import copy
+
+from transformers import PretrainedConfig
+
+IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json",
+ "HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json",
+}
+
+
+class IdeficsVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer (stored internally as `embed_dim`).
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ intermediate_size (`int`, *optional*, defaults to 5120):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_num_channels (`int`, *optional*, defaults to `3`):
+ Number of image channels.
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
+ testing).
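+ Example:
+ ```python
+ >>> # a minimal usage sketch, relying only on the defaults documented above
+ >>> vision_config = IdeficsVisionConfig()
+ >>> # `hidden_size` is an alias of `embed_dim` via `attribute_map`
+ >>> vision_config.hidden_size
+ 768
+ ```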
+ """
+ model_type = "idefics"
+ attribute_map = {
+ "hidden_size": "embed_dim",
+ }
+
+ def __init__(
+ self,
+ embed_dim=768,
+ image_size=224,
+ intermediate_size=5120,
+ patch_size=14,
+ num_hidden_layers=32,
+ num_attention_heads=16,
+ num_channels=3,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs,
+ ):
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.intermediate_size = intermediate_size
+ self.patch_size = patch_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.hidden_act = hidden_act
+
+ super().__init__(**kwargs)
+
+
+class IdeficsPerceiverConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ use_resampler (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the resampler
+ resampler_n_latents (`int`, *optional*, defaults to 64):
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+ resampler_depth (`int`, *optional*, defaults to 6):
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+ resampler_n_heads (`int`, *optional*, defaults to 16):
+ Number of heads in each Transformer block (for multi-headed self-attention).
+ resampler_head_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of each head projection in the Transformer block.
+ qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
+ Whether or not to use qk layer norms in perceiver
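+ Example:
+ ```python
+ >>> # a minimal sketch: the perceiver config can be built directly or passed to `IdeficsConfig` as a dict
+ >>> config = IdeficsConfig(perceiver_config={"use_resampler": True, "resampler_depth": 6})
+ >>> config.perceiver_config.use_resampler
+ True
+ ```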
+ """
+ model_type = "idefics"
+
+ def __init__(
+ self,
+ use_resampler=False,
+ resampler_n_latents=64,
+ resampler_depth=6,
+ resampler_n_heads=16,
+ resampler_head_dim=96,
+ qk_layer_norms_perceiver=False,
+ **kwargs,
+ ):
+ self.use_resampler = use_resampler
+ self.resampler_n_latents = resampler_n_latents
+ self.resampler_depth = resampler_depth
+ self.resampler_n_heads = resampler_n_heads
+ self.resampler_head_dim = resampler_head_dim
+ self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
+
+ super().__init__(**kwargs)
+
+
+class IdeficsConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ additional_vocab_size (`int`, *optional*, defaults to 0):
+ Additional vocabulary size of the model, typically for the special "<image>" token. Additional vocab tokens
+ are always trainable whereas regular vocab tokens can be frozen or not.
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~IdeficsModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
+ Initialization type for the alphas.
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
+ Attention.
+ alpha_type (`str`, *optional*, defaults to `"float"`):
+ Whether the gating alphas should be vectors or single floats.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ cross_layer_interval (`int`, *optional*, defaults to 1):
+ Interval for cross attention (from text to image) layers.
+ qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
+ freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
+ freeze_text_module_exceptions (`List[str]`, *optional*, defaults to `[]`):
+ Exceptions to freezing text layers when `freeze_text_layers` is `True`
+ freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
+ freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
+ freeze_vision_module_exceptions (`List[str]`, *optional*, defaults to `[]`):
+ Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
+ use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
+ vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
+ perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
+ Example:
+ ```python
+ >>> from transformers import IdeficsModel, IdeficsConfig
+ >>> # Initializing a Idefics idefics-9b style configuration
+ >>> configuration = IdeficsConfig()
+ >>> # Initializing a model from the idefics-9b style configuration
+ >>> model = IdeficsModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "idefics"
+ is_composition = True
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ additional_vocab_size=0,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ dropout=0.0,
+ hidden_act="silu",
+ initializer_range=0.02,
+ alpha_initializer="zeros",
+ alphas_initializer_range=0.0,
+ alpha_type="float",
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ cross_layer_interval=1,
+ qk_layer_norms=False,
+ freeze_text_layers=True,
+ freeze_text_module_exceptions=[],
+ freeze_lm_head=False,
+ freeze_vision_layers=True,
+ freeze_vision_module_exceptions=[],
+ use_resampler=False,
+ vision_config=None,
+ perceiver_config=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.additional_vocab_size = additional_vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.alpha_initializer = alpha_initializer
+ self.alphas_initializer_range = alphas_initializer_range
+ self.alpha_type = alpha_type
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+
+ self.cross_layer_interval = cross_layer_interval
+ self.qk_layer_norms = qk_layer_norms
+ self.freeze_vision_layers = freeze_vision_layers
+
+ self.freeze_text_layers = freeze_text_layers
+ self.freeze_text_module_exceptions = freeze_text_module_exceptions
+ self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
+ self.freeze_lm_head = freeze_lm_head
+
+ self.use_resampler = use_resampler
+
+ if perceiver_config is None:
+ self.perceiver_config = IdeficsPerceiverConfig()
+ elif isinstance(perceiver_config, dict):
+ self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config)
+ elif isinstance(perceiver_config, IdeficsPerceiverConfig):
+ self.perceiver_config = perceiver_config
+
+ if vision_config is None:
+ self.vision_config = IdeficsVisionConfig()
+ elif isinstance(vision_config, dict):
+ self.vision_config = IdeficsVisionConfig(**vision_config)
+ elif isinstance(vision_config, IdeficsVisionConfig):
+ self.vision_config = vision_config
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
+ # of this object many attributes have default values and haven't yet been overridden.
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
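+
+ Example (a sketch): nested configs are serialized as plain dicts, e.g.
+ `IdeficsConfig().to_dict()["vision_config"]["embed_dim"] == 768`.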
+ """
+ output = copy.deepcopy(self.__dict__)
+
+ output["vision_config"] = self.vision_config.to_dict()
+ output["perceiver_config"] = self.perceiver_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+
+ return output
diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
new file mode 100644
index 00000000..727f94c6
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Idefics."""
+
+from typing import Callable, Dict, List, Optional, Union
+
+from PIL import Image
+
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import resize, to_channel_dimension_format
+from transformers.image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from transformers import TensorType, is_torch_available
+
+
+IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
+IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
+
+
+def convert_to_rgb(image):
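+ """Convert an image to RGB, compositing any alpha channel over a white background.
+
+ A sketch of the behaviour:
+ >>> convert_to_rgb(Image.new("RGBA", (1, 1), (0, 0, 0, 0))).getpixel((0, 0))
+ (255, 255, 255)
+ """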
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
+ # for transparent images. The call to `alpha_composite` handles this case
+ if image.mode == "RGB":
+ return image
+
+ image_rgba = image.convert("RGBA")
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+ alpha_composite = Image.alpha_composite(background, image_rgba)
+ alpha_composite = alpha_composite.convert("RGB")
+ return alpha_composite
+
+
+class IdeficsImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an Idefics image processor.
+ Args:
+ image_size (`int`, *optional*, defaults to `224`):
+ The size (resolution) to resize images to.
+ image_num_channels (`int`, *optional*, defaults to `3`):
+ Number of image channels.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
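+ Example:
+ ```python
+ >>> from PIL import Image
+ >>> # a minimal sketch: resize and normalize a single image into a [1, 3, 224, 224] tensor
+ >>> processor = IdeficsImageProcessor(image_mean=IDEFICS_STANDARD_MEAN, image_std=IDEFICS_STANDARD_STD)
+ >>> processor.preprocess([Image.new("RGB", (640, 480))]).shape
+ torch.Size([1, 3, 224, 224])
+ ```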
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ image_num_channels: Optional[int] = 3,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.image_num_channels = image_num_channels
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ image_num_channels: Optional[int] = 3,
+ image_size: Optional[Dict[str, int]] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ transform: Callable = None,
+ **kwargs,
+ ) -> TensorType.PYTORCH:
+ """
+ Preprocess a batch of images.
+ Args:
+ images (`ImageInput`):
+ A list of images to preprocess.
+ image_size (`int`, *optional*, defaults to `self.image_size`):
+ The size (resolution) to resize images to.
+ image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
+ Number of image channels.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image.
+ transform (`Callable`, *optional*, defaults to `None`):
+ A custom transform function that accepts a single image can be passed for training. For example,
+ `torchvision.Compose` can be used to compose multiple transforms. If `None`, inference mode is
+ assumed and a preset of inference-specific transforms will be applied to the images
+ Returns:
+ a PyTorch tensor of the processed images
+ """
+ image_size = image_size if image_size is not None else self.image_size
+ image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ size = (image_size, image_size)
+
+ if len(images) == 0:
+ return []
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # For training a user needs to pass their own set of transforms as a Callable.
+ # For reference this is what was used in the original IDEFICS training:
+ # transform = transforms.Compose([
+ # convert_to_rgb,
+ # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
+ # transforms.ToTensor(),
+ # transforms.Normalize(mean=image_mean, std=image_std),
+ # ])
+ if transform is not None:
+ if not is_torch_available():
+ raise ImportError("To pass in `transform` torch must be installed")
+ import torch
+
+ images = [transform(x) for x in images]
+ return torch.stack(images)
+
+ # for inference we do the exact transforms that were used to train IDEFICS
+ images = [convert_to_rgb(x) for x in images]
+ # further transforms expect numpy arrays
+ images = [to_numpy_array(x) for x in images]
+ images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
+ images = [self.rescale(image=image, scale=1 / 255) for image in images]
+ images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
+ images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
+ # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
+ images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+
+ return images
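+
+
+# Make this class discoverable on the `transformers` module so that name-based lookups
+# (e.g. `ProcessorMixin` resolving `image_processor_class = "IdeficsImageProcessor"`) can
+# find this local implementation instead of expecting it inside `transformers` itself.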
+import transformers
+transformers.IdeficsImageProcessor = IdeficsImageProcessor
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 54bbecf7..90eb0463 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -36,7 +36,7 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
-from transformers import IdeficsConfig
+from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer
from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler
from text_generation_server.utils.layers import (
diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
index c0e5b400..def78390 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
@@ -41,7 +41,6 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
-from transformers import IdeficsConfig
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
@@ -53,7 +52,7 @@ class IdeficsPerceiverResampler(nn.Module):
def __init__(
self,
prefix,
- config: IdeficsConfig,
+ config,
embed_dim: int,
depth: int,
n_heads: int,
@@ -223,7 +222,7 @@ class IdeficsMLP(nn.Module):
def __init__(self,
prefix,
intermediate_size,
- config: IdeficsConfig,
+ config,
weights,
):
"""Simple MLP block with intermediate_size and embedding size"""
diff --git a/server/text_generation_server/models/custom_modeling/idefics_processing.py b/server/text_generation_server/models/custom_modeling/idefics_processing.py
new file mode 100644
index 00000000..e24fc7bd
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_processing.py
@@ -0,0 +1,413 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for IDEFICS.
+"""
+
+from typing import Callable, List, Optional, Union
+from urllib.parse import urlparse
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
+from transformers.utils import TensorType, is_torch_available
+from text_generation_server.models.custom_modeling.idefics_image_processing import IdeficsImageProcessor
+
+
+if is_torch_available():
+ import torch
+
+
+IMAGE_TOKEN = ""
+
+
+# copied from m4.training.packing
+def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
+ # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
+
+ # If any of images index are more than num_classes, set them to -1.
+ # Words after the max number of images allowed have been seen don't attend on anything
+ if num_classes != -1:
+ incremental_mask[incremental_mask >= num_classes] = -1
+
+ negatives = incremental_mask == -1
+ incremental_mask[negatives] = 0
+ attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
+ attn_mask[negatives, :] = 0
+ return attn_mask
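+
+# A sketch of the mapping above (assuming torch is available):
+# >>> incremental_to_binary_attention_mask(torch.tensor([[-1, 0, 1]]), num_classes=2)
+# tensor([[[0, 0], [1, 0], [0, 1]]])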
+
+
+# copied from m4.training.packing
+def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
+ image_attention_mask = torch.full_like(input_ids, fill_value=-1)
+ next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
+ image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+ eod_token_id = tokenizer.eos_token_id
+ for batch_idx in range(input_ids.size(0)):
+ count = -1
+ seen_eod = False
+ for idx, token_id in enumerate(input_ids[batch_idx]):
+ if token_id == image_token_id:
+ count += 1
+ image_attention_mask[batch_idx][idx] = count
+ seen_eod = False
+ else:
+ image_attention_mask[batch_idx][idx] = count
+
+ if seen_eod:
+ image_attention_mask[batch_idx][idx] = -1
+
+ if token_id == eod_token_id:
+ seen_eod = True
+
+ for batch_idx in range(input_ids.size(0)):
+ count = -1
+ seen_eod = False
+ for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
+ token_id = input_ids[batch_idx][idx]
+ if token_id == image_token_id:
+ count += 1
+ next_image_attention_mask[batch_idx][idx] = count
+ seen_eod = False
+ else:
+ next_image_attention_mask[batch_idx][idx] = count
+
+ if token_id == eod_token_id:
+ seen_eod = True
+
+ if seen_eod:
+ next_image_attention_mask[batch_idx][idx] = -1
+
+ non_negative_indices = next_image_attention_mask[batch_idx] != -1
+ next_image_attention_mask[batch_idx][non_negative_indices] -= count
+ next_image_attention_mask[batch_idx][non_negative_indices] *= -1
+
+ return image_attention_mask, next_image_attention_mask
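+
+# Note: the first returned mask records, for every position, the index of the last <image> token seen
+# so far (-1 if none yet), e.g. ids [<s>, <image>, txt, txt] yield [-1, 0, 0, 0]. The "next" variant
+# counts images still to come and is not used by `IdeficsProcessor.__call__` below, which keeps only
+# the first mask.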
+
+
+def is_url(string):
+ """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
+ invalidated the url"""
+ if " " in string:
+ return False
+ result = urlparse(string)
+ return all([result.scheme, result.netloc])
+
+
+class IdeficsProcessor(ProcessorMixin):
+ r"""
+ Constructs an IDEFICS processor which wraps a LLaMA tokenizer and an IDEFICS image processor into a single processor.
+
+ [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See
+ the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
+
+ Args:
+ image_processor (`IdeficsImageProcessor`):
+ An instance of [`IdeficsImageProcessor`]. The image processor is a required input.
+ tokenizer (`LlamaTokenizerFast`):
+ An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
+ image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
+ """
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "IdeficsImageProcessor"
+ tokenizer_class = "LlamaTokenizerFast"
+
+ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+ self.current_processor = self.image_processor
+ self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+
+ self.default_image_dims = (
+ self.image_processor.image_num_channels,
+ self.image_processor.image_size,
+ self.image_processor.image_size,
+ )
+
+ self.tokenizer_was_trained_with_end_of_utterance_token = (
+ True
+ if "" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
+ else False
+ )
+
+ def __call__(
+ self,
+ prompts: Union[List[TextInput], List[List[TextInput]]],
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ transform: Callable = None,
+ add_eos_token=False,
+ add_end_of_utterance_token=None,
+ debug=False,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ ) -> BatchEncoding:
+ """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
+ the model was trained on and prepares the image pixel values for the model to process.
+
+ Args:
+ prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):
+ either a single prompt or a batched list of prompts - see the detailed description immediately after
+ the end of the arguments doc section.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ transform (`Callable`, *optional*):
+ A custom transform function that accepts a single image can be passed for training. For example,
+ `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
+ set of transforms will be applied to the images
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Adds `eos_token` at the end of the final prompt if `True`
+ add_end_of_utterance_token (`bool`, *optional*):
+ Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
+ image). If `None` the tokenizer will be checked instead and if this token is found in
+ `additional_special_tokens` then the value will be `True`.
+ debug (`bool`, *optional*, defaults to `False`):
+ `True` value will help debug prompt generation by dumping useful information
+ return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
+ The type of tensors to return. Can be one of:
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+
+ Returns:
+ a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
+ directly passed to `model.generate`
+
+ Detailed explanation:
+
+ Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
+
+ An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
+
+ When the processor encounters an image it'll inject a
+ `<fake_token_around_image><image><fake_token_around_image>` entry into the prompt.
+
+ Example:
+
+ ```python
+ checkpoint = "HuggingFaceM4/idefics-9b"
+ processor = AutoProcessor.from_pretrained(checkpoint)
+ url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
+ img = processor.image_processor.fetch_images([url])[0]
+
+ prompts = [
+ "User:",
+ img,
+ "Describe this image.\nAssistant: An image of two kittens in grass.\n",
+ "User:",
+ "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
+ "Describe this image.\nAssistant:",
+ ]
+
+ inputs = processor(prompts, return_tensors="pt")
+ generated_ids = model.generate(**inputs, max_length=100)
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```
+
+ In this example the `prompts` will be converted into:
+
+ ```
+ <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+ Assistant: An image of two kittens in grass.
+ User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+ Assistant:'
+ ```
+
+ and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the
+ `pixel_values` dict entry of the return value.
+
+ This example also shows that images can be passed as objects or as text urls. It can be seen that the
+ first image is passed as an object and the second one as a url.
+
+ To do training do:
+
+ ```python
+ image_transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=self.image_mean, std=self.image_std),
+ ]
+ )
+ inputs = processor(prompts, transform=image_transform, return_tensors="pt")
+ ```
+
+ To help debug prompt generation, enable `debug=True`, which will show you what's happening.
+
+ """
+
+ # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
+ if add_end_of_utterance_token is None:
+ add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
+
+ # turn non-batched prompts into batched
+ if not any(isinstance(i, list) for i in prompts):
+ prompts = [prompts]
+
+ fake_token = ""
+ image_token = ""
+ end_of_utterance_token = ""
+
+ def image_tokens(last_was_image):
+ if last_was_image:
+ return image_token + fake_token
+ else:
+ return fake_token + image_token + fake_token
+
+ all_texts = []
+ all_images = []
+ for sample in prompts:
+ # the model was trained on samples starting with <s>
+ full_text = f"{self.tokenizer.bos_token}"
+
+ # an image can either be an image object in the item or the url, everything else is a verbatim prompt text
+ image_objects = []
+ last_was_image = False
+ last_was_text = False
+ for i, item in enumerate(sample):
+ if i > 0:
+ last_was_text = not last_was_image
+
+ if isinstance(item, str):
+ item = item.strip(" ")
+ if is_url(item):
+ image = self.image_processor.fetch_images(item)
+ full_text += image_tokens(last_was_image)
+ image_objects.append(image)
+ last_was_image = True
+ else:
+ # we add end_of_utterance_token between each subsequent text prompt (but not after the last one!)
+ if add_end_of_utterance_token and last_was_text:
+ full_text += end_of_utterance_token
+ full_text += item
+ last_was_image = False
+ else:
+ # must be an image obj
+ full_text += image_tokens(last_was_image)
+ image_objects.append(item)
+ last_was_image = True
+
+ if add_eos_token:
+ full_text += self.tokenizer.eos_token
+
+ if debug is True:
+ print(f"{full_text=}")
+
+ image_objects = self.image_processor(image_objects, transform=transform)
+
+ text_encoding = self.tokenizer(
+ text=full_text,
+ add_special_tokens=False,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ )
+
+ all_texts.append(text_encoding["input_ids"])
+ all_images.append(image_objects)
+
+ max_seq_len = max(len(x) for x in all_texts)
+
+ # max_num_images has to be at least 1 even when there are no images
+ max_num_images = max(len(x) for x in all_images)
+ max_num_images = max(1, max_num_images)
+
+ at_least_one_image = sum(len(x) for x in all_images) > 0
+ output_input_ids = []
+ output_images = []
+ output_attention_masks = []
+ for text, images in zip(all_texts, all_images):
+ padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len
+ unpadded_seq_len = len(text)
+ start = max_seq_len - unpadded_seq_len
+ padded_input_ids[start:] = text[:max_seq_len]
+
+ attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
+ attention_mask[start:] = 1
+
+ image_count = padded_input_ids.count(self.image_token_id)
+ local_max_num_images = min(image_count, max_num_images)
+
+ current_images = images[:local_max_num_images]
+
+ if len(current_images) > 0:
+ padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
+ padded_image_tensor[: current_images.size(0)] = current_images
+ else:
+ padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+
+ output_images.append(padded_image_tensor)
+ output_input_ids.append(torch.tensor(padded_input_ids))
+
+ output_attention_masks.append(attention_mask)
+
+ output_input_ids = torch.stack(output_input_ids)
+ output_images = torch.stack(output_images)
+ output_attention_masks = torch.stack(output_attention_masks)
+
+ if at_least_one_image:
+ image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
+ image_attention_mask = incremental_to_binary_attention_mask(
+ image_attention_mask, num_classes=max_num_images
+ )
+ else:
+ # in full language mode we set the image mask to all-0s
+ image_attention_mask = torch.zeros(
+ output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
+ )
+
+ return BatchFeature(
+ data={
+ "input_ids": output_input_ids,
+ "attention_mask": output_attention_masks,
+ "pixel_values": output_images,
+ "image_attention_mask": image_attention_mask,
+ }
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py
index 6caf2918..d933d7c1 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_vision.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py
@@ -28,7 +28,6 @@ from transformers.utils import (
ModelOutput,
logging,
)
-from transformers.models.idefics.configuration_idefics import IdeficsVisionConfig
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
@@ -69,7 +68,7 @@ class IdeficsVisionModelOutput(ModelOutput):
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
class IdeficsVisionEmbeddings(nn.Module):
- def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -253,7 +252,7 @@ class IdeficsVisionMLP(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
class IdeficsVisionEncoderLayer(nn.Module):
- def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
@@ -316,7 +315,7 @@ class IdeficsVisionEncoder(nn.Module):
config: IdeficsVisionConfig
"""
- def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
@@ -417,7 +416,7 @@ class IdeficsVisionEncoder(nn.Module):
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
class IdeficsVisionTransformer(nn.Module):
- def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index b507da19..07fea1f2 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -10,6 +10,9 @@ from transformers import (
)
from text_generation_server.models import IdeficsCausalLM
+from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
+from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
+from transformers import LlamaTokenizerFast
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
@@ -38,7 +41,7 @@ class IDEFICSSharded(IdeficsCausalLM):
dtype = torch.float32
self.device, self.dtype = device, dtype
- config = AutoConfig.from_pretrained(
+ config = IdeficsConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
@@ -46,14 +49,14 @@ class IDEFICSSharded(IdeficsCausalLM):
config.quantize = quantize
config.vision_config.quantize = quantize
- tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer = LlamaTokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
- self.processor = AutoProcessor.from_pretrained(
+ self.processor = IdeficsProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 7245bc65..2d1f418b 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -8,7 +8,8 @@ import json
from dataclasses import dataclass
from opentelemetry import trace
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin, IdeficsForVisionText2Text
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
+from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model