Backport to transformers==4.31.

This commit is contained in:
Nicolas Patry 2023-08-17 07:28:14 +00:00
parent 308ab7d5b9
commit 6c699d86bf

View File

@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
"""Image processor class for Idefics.""" """Image processor class for Idefics."""
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union, Iterable
import numpy as np
from PIL import Image from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import resize, to_channel_dimension_format from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize
from transformers.image_utils import ( from transformers.image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
@ -186,5 +187,78 @@ class IdeficsImageProcessor(BaseImageProcessor):
else: else:
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}") raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
def rescale(
self,
image: np.ndarray,
scale: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
# return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
# requires 4.32
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `Iterable[float]`):
Image mean to use for normalization.
std (`float` or `Iterable[float]`):
Image standard deviation to use for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
# TODO 4.32
return normalize(
image, mean=mean, std=std, data_format=data_format, **kwargs
)
import transformers import transformers
transformers.IdeficsImageProcessor = IdeficsImageProcessor transformers.IdeficsImageProcessor = IdeficsImageProcessor