From eaf9448b485562e7fbc32ce1432b49825426011f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 14 Aug 2023 16:05:47 +0000 Subject: [PATCH] Adding Idefics multi modal model. Co-Authored-By: Victor Sanh --- .../text_generation_server/models/__init__.py | 11 + .../custom_modeling/idefics_modeling.py | 1318 +++++++++++++++++ .../custom_modeling/idefics_perceiver.py | 246 +++ .../models/custom_modeling/idefics_vision.py | 474 ++++++ .../text_generation_server/models/idefics.py | 112 ++ .../models/idefics_causal_lm.py | 837 +++++++++++ server/text_generation_server/server.py | 23 +- server/text_generation_server/utils/layers.py | 24 + 8 files changed, 3039 insertions(+), 6 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/idefics_modeling.py create mode 100644 server/text_generation_server/models/custom_modeling/idefics_perceiver.py create mode 100644 server/text_generation_server/models/custom_modeling/idefics_vision.py create mode 100644 server/text_generation_server/models/idefics.py create mode 100644 server/text_generation_server/models/idefics_causal_lm.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 621652e8..518e1a15 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,8 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM +from text_generation_server.models.idefics import IDEFICSSharded # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -40,6 +42,7 @@ __all__ = [ "OPTSharded", "T5Sharded", "get_model", + "IDEFICSSharded", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -248,6 +251,14 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == "idefics": + return IDEFICSSharded( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if sharded: raise ValueError("sharded is not supported for AutoModel") diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py new file mode 100644 index 00000000..42c079d5 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -0,0 +1,1318 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics model.""" +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PretrainedConfig +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers import IdeficsConfig +from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer +from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, + PositionRotaryEmbedding, + FastLinear, +) + + +# logger = logging.get_logger(__name__) + +# _CONFIG_FOR_DOC = "IdeficsConfig" + +# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [ +# "HuggingFaceM4/idefics-9b", +# "HuggingFaceM4/idefics-80b", +# # See all Idefics models at https://huggingface.co/models?filter=idefics +# ] + + +def expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=False, + attention_mask=None, + encoder_outputs=None, + **model_kwargs, +): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select( + 0, expanded_return_idx + ) + model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + +def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # must have this key set to at least None + model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None) + + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention masks + if not is_encoder_decoder: + if "attention_mask" in 
model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+        if "image_attention_mask" in model_kwargs:
+            image_attention_mask = model_kwargs["image_attention_mask"]
+            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+            model_kwargs["image_attention_mask"] = last_mask
+
+    return model_kwargs
+
+
+def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
+    token_type_ids = kwargs.get("token_type_ids", None)
+    # only last token for inputs_ids if past is defined in kwargs
+    if past:
+        input_ids = input_ids[:, -1].unsqueeze(-1)
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+    attention_mask = kwargs.get("attention_mask", None)
+    position_ids = kwargs.get("position_ids", None)
+
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        if past:
+            position_ids = position_ids[:, -1].unsqueeze(-1)
+
+    pixel_values = kwargs.get("pixel_values", None)
+    image_attention_mask = kwargs.get("image_attention_mask", None)
+    # if pixel_values is None or image_attention_mask is None:
+    #     raise ValueError("pixel values and image attention mask cannot be None")
+
+    return {
+        "input_ids": input_ids,
+        "past_key_values": past,
+        "use_cache": kwargs.get("use_cache"),
+        "position_ids": position_ids,
+        "attention_mask": attention_mask,
+        "token_type_ids": token_type_ids,
+        "pixel_values": pixel_values,
+        "image_attention_mask": image_attention_mask,
+    }
+
+
+def freeze_model(model, module_exceptions=[]):
+    mapping = {
+        "LayerNorm": nn.LayerNorm,
+        "Linear": nn.Linear,
+        "Embedding": nn.Embedding,
+    }
+    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
+    for module in model.modules():
+        if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
+            module.requires_grad_(True)  # Explicitly setting it to true to avoid any mistakes
+        else:
+            module.requires_grad_(False)
+    return model
+
+
+class IdeficsDecoupledPartialTPEmbedding(nn.Module):
+    def __init__(
+        self,
+        config,
+        weights,
+    ):
+        super().__init__()
+        self.num_embeddings = config.vocab_size
+        self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
+        self.additional_weight = nn.Parameter(weights.get_tensor("model.embed_tokens.additional_embedding.weight"))
+
+    def forward(self, input_ids):
+        # Clone so that we don't modify the original input_ids later on
+        input_ids = input_ids.clone()
+        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
+        input_ids_additional_vocab = input_ids[additional_vocab_indices]
+        additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight)
+
+        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+        input_ids[additional_vocab_indices] = 0
+        full_vector = self.weight(input_ids)
+
+        # overwrite the records with high indices
+        full_vector[additional_vocab_indices] = additional_embeddings
+
+        return full_vector
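For intuition, the split-lookup trick in `IdeficsDecoupledPartialTPEmbedding.forward` can be exercised with plain `nn.Embedding` layers standing in for the tensor-parallel weights. A minimal sketch with made-up sizes, not the serving path:

```python
import torch
import torch.nn as nn

vocab_size, extra_tokens, dim = 8, 2, 4
base = nn.Embedding(vocab_size, dim)      # stands in for TensorParallelEmbedding
extra = nn.Embedding(extra_tokens, dim)   # stands in for additional_weight

input_ids = torch.tensor([[1, 7, 8, 9]])  # ids 8 and 9 live in the additional vocab

ids = input_ids.clone()
is_extra = ids >= vocab_size
extra_vectors = extra(ids[is_extra] - vocab_size)
ids[is_extra] = 0                         # dummy id so the base lookup succeeds
full_vector = base(ids)
full_vector[is_extra] = extra_vectors     # overwrite the rows with high indices
print(full_vector.shape)                  # torch.Size([1, 4, 4])
```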
+
+
+class IdeficsDecoupledTensorParallelLinear(nn.Module):
+    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
+    """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
+    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
+    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
+    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
+    """
+
+    def __init__(
+        self,
+        config,
+        weights,
+    ) -> None:
+        super().__init__()
+        self.fc = TensorParallelHead.load(
+            config=config, prefix="lm_head", weights=weights
+        )
+        self.additional_fc = FastLinear.load(
+            config=config, prefix="lm_head.additional_fc", weights=weights, bias=False,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = self.fc(input)
+        additional_features = self.additional_fc(input)
+        output = torch.cat((output, additional_features), -1)
+
+        return output
+
+    def extra_repr(self) -> str:
+        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
+        # NOTE: unlike the reference implementation, this tensor-parallel variant does not
+        # track `in_features`/`out_features`/`partially_freeze`, so only report the wrapped
+        # projections instead of referencing attributes that are never set.
+        return "fc={}, additional_fc={}".format(self.fc, self.additional_fc)
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for uni-directional (causal) self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# this was adapted from LlamaRMSNorm +class IdeficsRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + IdeficsRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# this was adapted from LlamaRotaryEmbedding +class IdeficsEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# this was adapted from LlamaMLP +class IdeficsMLP(nn.Module): + def __init__( + self, + config, + prefix, + weights, + ): + super().__init__() + self.gate_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, + ) + self.up_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.up_proj", weights=weights, bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# this was adapted from LlamaAttention +class IdeficsAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config, + prefix, + weights, + qk_layer_norms: bool = False, + is_cross_attention: bool = False, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.dropout = config.dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.is_cross_attention = is_cross_attention + + # if not hasattr(nn.functional, "scaled_dot_product_attention"): + # raise ValueError("this model requires pytorch 2.0 or higher") + + process_group = weights.process_group + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + + if self.is_cross_attention: + # kv_input_dim = ( + # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim + # ) + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=False + ) + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=False + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=False + ) + else: + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=False + ) + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=False + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=False + ) + self.o_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.o_proj", weights=weights, bias=False + ) + # self.rotary_emb = PositionRotaryEmbedding.load( + # prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = IdeficsEmbedding(self.head_dim, device="cuda:0") #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash + + self.qk_layer_norms = qk_layer_norms + if self.qk_layer_norms: + self.q_layer_norm = IdeficsRMSNorm( + prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_layer_norm = IdeficsRMSNorm( + prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = self.is_cross_attention or key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if not is_cross_attention: + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + else: + _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = ( + self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is 
not None: + kv_seq_len += past_key_value[0].shape[-2] + if not is_cross_attention: + cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len)) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # cos, sin = self.rotary_emb.get_cos_sin( + # position_ids=torch.arange(), + # max_s=max(kv_seq_len, q_len), + # dtype=hidden_states.dtype, + # ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.qk_layer_norms: + query_states = self.q_layer_norm(query_states) + key_states = self.k_layer_norm(key_states) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_output = nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + attn_weights = None + if output_attentions: + logger.warning_once( + "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" + ) + + return attn_output, attn_weights, past_key_value + + +# this was adapted from LlamaDecoderLayer +class IdeficsDecoderLayer(nn.Module): + def __init__(self, layer_id: int, config: IdeficsConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.hidden_size = config.hidden_size + prefix = f"model.layers.{layer_id}" + self.self_attn = IdeficsAttention( + config=config, + prefix=f"{prefix}.self_attn", + weights=weights, + qk_layer_norms=False, + is_cross_attention=False, + ) + self.mlp = IdeficsMLP( + config=config, + prefix=f"{prefix}.mlp", + weights=weights, + ) + self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) + self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps) + self.dropout = config.dropout + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
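As a toy illustration of the cache bookkeeping in the attention forward above (a shape check only, not part of the model): cached keys and values are concatenated ahead of the freshly projected position along the sequence axis, so `kv_seq_len` grows by one per decoding step.

```python
import torch

bsz, n_heads, head_dim = 1, 2, 4
past_key = torch.randn(bsz, n_heads, 5, head_dim)  # 5 previously cached positions
new_key = torch.randn(bsz, n_heads, 1, head_dim)   # 1 freshly projected position

key_states = torch.cat([past_key, new_key], dim=2)
print(key_states.shape)  # torch.Size([1, 2, 6, 4]) -> kv_seq_len is now 6
```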
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class IdeficsGatedCrossAttentionLayer(nn.Module): + def __init__(self, layer_id, config: IdeficsConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.hidden_size = config.hidden_size + prefix = f"model.gated_cross_attn_layers.{layer_id}" + self.cross_attn = IdeficsAttention( + config=config, + prefix=f"{prefix}.cross_attn", + weights=weights, + qk_layer_norms=True, + is_cross_attention=True, + ) + self.mlp = IdeficsMLP( + config=config, + prefix=f"{prefix}.mlp", + weights=weights, + ) + self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) + self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps) + self.config = config.dropout + + self.act_cross_attn = nn.Tanh() + self.act_dense = nn.Tanh() + + self.alpha_cross_attn = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_cross_attn")) + self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense")) + + if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): + raise ValueError("Alpha parameters not initialized correctly!") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_hidden_states: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + no_images: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
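The gating used by this layer (in the body that follows) can be sketched in isolation: with `alpha_cross_attn` initialized at zero, `tanh(alpha)` is zero and the whole cross-attention branch is an identity map at the start of training, while the `no_images` flag hard-zeroes the image pathway. A rough sketch with made-up shapes:

```python
import torch
import torch.nn as nn

alpha_cross_attn = nn.Parameter(torch.zeros(1))  # tanh(0) == 0 at initialization
residual = torch.randn(1, 3, 8)
cross_attn_out = torch.randn(1, 3, 8)            # pretend attention output

no_images = False
gate = 0 if no_images else 1
hidden = residual + gate * torch.tanh(alpha_cross_attn) * cross_attn_out
print(torch.allclose(hidden, residual))          # True while alpha stays at zero
```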
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored + """ + if image_hidden_states is None: + raise ValueError( + "`image_hidden_states` is required for Idefics cross attention module which are visual features to be" + " conditioned on." + ) + + if past_key_value is not None: + raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=image_hidden_states, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + # when there are no images the model is used in pure language mode + gate = 0 if no_images else 1 + hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`IdeficsConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# @add_start_docstrings( +# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", +# LLAMA_START_DOCSTRING, +# ) +class IdeficsPreTrainedModel(PreTrainedModel): + config_class = IdeficsConfig + # base_model_prefix = "model" + # supports_gradient_checkpointing = True + # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] + + # def _init_weights(self, module): + # # important: this ported version of Idefics isn't meant for training from scratch - only + # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code + # # base should be used for training from scratch and it contains the correct code. 
+ # std = self.config.initializer_range + # if isinstance(module, nn.Linear): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.bias is not None: + # module.bias.data.zero_() + # elif isinstance(module, nn.Embedding): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.padding_idx is not None: + # module.weight.data[module.padding_idx].zero_() + + # def _set_gradient_checkpointing(self, module, value=False): + # if isinstance(module, IdeficsModel): + # module.gradient_checkpointing = value + + +# LLAMA_INPUTS_DOCSTRING = r""" +# Args: +# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): +# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide +# it. + +# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +# [`PreTrainedTokenizer.__call__`] for details. + +# [What are input IDs?](../glossary#input-ids) +# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): +# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + +# - 1 for tokens that are **not masked**, +# - 0 for tokens that are **masked**. + +# [What are attention masks?](../glossary#attention-mask) + +# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +# [`PreTrainedTokenizer.__call__`] for details. + +# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see +# `past_key_values`). + +# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] +# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more +# information on the default strategy. + +# - 1 indicates the head is **not masked**, +# - 0 indicates the head is **masked**. +# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): +# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, +# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) +# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): +# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape +# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape +# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + +# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention +# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + +# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that +# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all +# `decoder_input_ids` of shape `(batch_size, sequence_length)`. +# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): +# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This +# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the +# model's internal embedding lookup matrix. 
+# use_cache (`bool`, *optional*): +# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see +# `past_key_values`). +# output_attentions (`bool`, *optional*): +# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned +# tensors for more detail. +# output_hidden_states (`bool`, *optional*): +# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for +# more detail. +# return_dict (`bool`, *optional*): +# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +# """ + + +# @add_start_docstrings( +# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", +# LLAMA_START_DOCSTRING, +# ) +class IdeficsModel(IdeficsPreTrainedModel): + # """ + # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`] + + # Args: + # config: IdeficsConfig + # """ + + def __init__(self, config: IdeficsConfig, weights): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = IdeficsDecoupledPartialTPEmbedding( + config=config, + weights=weights, + ) + + self.image_size = config.vision_config.image_size + self.vision_config = config.vision_config + self.vision_model = IdeficsVisionTransformer( + prefix="model.vision_model", + config=config.vision_config, + weights=weights, + ) + + # Perceiver Resampler + if config.use_resampler: + perceiver_config = config.perceiver_config + self.perceiver_resampler = IdeficsPerceiverResampler( + prefix=f"model.perceiver_resampler", + config=config, + embed_dim=config.vision_config.embed_dim, + depth=perceiver_config.resampler_depth, + n_heads=perceiver_config.resampler_n_heads, + head_dim=perceiver_config.resampler_head_dim, + n_latents=perceiver_config.resampler_n_latents, + weights=weights, + ) + + self.layers = nn.ModuleList( + [ + IdeficsDecoderLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.cross_layer_interval = config.cross_layer_interval + num_cross_layers = config.num_hidden_layers // self.cross_layer_interval + self.gated_cross_attn_layers = nn.ModuleList( + [ + IdeficsGatedCrossAttentionLayer(layer_id, config, weights) + for layer_id in range(num_cross_layers)] + ) + # self.gradient_checkpointing = False + + self.norm = IdeficsRMSNorm(prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps) + + # self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + + # self.freeze_relevant_params(config) + + # def freeze_relevant_params(self, config=None): + # if config is None: + # config = self.config + + # if config.freeze_text_layers: + # self.freeze_text_layers(config.freeze_text_module_exceptions) + + # if config.freeze_vision_layers: + # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) + + # def freeze_text_layers(self, module_exceptions=[]): + # for module in [self.layers, self.norm]: + # freeze_model(module, module_exceptions=module_exceptions) + + # def freeze_vision_layers(self, module_exceptions=[]): + # freeze_model(self.vision_model, module_exceptions=module_exceptions) + + # def get_input_embeddings(self): + # return self.embed_tokens + + # def set_input_embeddings(self, value): + # self.embed_tokens = value + + # Copied from 
transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + device = input_ids.device if input_ids is not None else inputs_embeds.device + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + elif position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + no_images = False + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values and image_embeddings have to be not-None.") + + 
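The on-the-fly `position_ids` computation above deserves a worked example: for a left-padded row, the cumulative sum over the mask yields positions starting at 0 on the first real token, and pad positions are clamped to an arbitrary valid index (1 here) since they are masked out of attention anyway.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # left-padded sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]])
```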
elif pixel_values is not None and image_embeddings is not None: + raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time") + + elif pixel_values is not None: + no_images = len(torch.nonzero(pixel_values)) == 0 + pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility + batch_size, num_images = pixel_values.shape[:2] + pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:]) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state + + elif image_embeddings is not None: + batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size() + image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device) + image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size) + + if self.config.use_resampler: + image_hidden_states = self.perceiver_resampler(image_hidden_states) + image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2) + image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size) + # # Hack to use the model in full language modeling mode + # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device) + # Make image_attention_mask compatible with hidden states + text_seq_len = image_attention_mask.size(1) + image_attention_mask = image_attention_mask.unsqueeze(-1) + image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len) + image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len) + + if image_hidden_states is not None: + image_batch_size, image_sequence_length, _ = image_hidden_states.size() + image_hidden_shape = (image_batch_size, image_sequence_length) + if image_attention_mask is None: + image_attention_mask = torch.ones(image_hidden_shape, device=device) + image_attention_mask = self.invert_attention_mask(image_attention_mask) + else: + image_attention_mask = None + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # if self.gradient_checkpointing and self.training: + # if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ # ) + # use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + def vblock( + main_block, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + output_attentions, + use_cache, + no_images, + layer_idx, + cross_layer_interval, + gated_cross_attn_layers, + ): + # TODO(ls): Add cross attention values to respective lists + if layer_idx % cross_layer_interval == 0: + xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] + outputs = xblock( + hidden_states, + attention_mask=attention_mask, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + output_attentions=output_attentions, + use_cache=use_cache, + past_key_value=None, # not implemented + no_images=no_images, + ) + hidden_states = outputs[0] + + layer_outputs = main_block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + return layer_outputs + + # if self.gradient_checkpointing and self.training: + # past_key_value = None + # if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + # use_cache = False + + # layer_outputs = torch.utils.checkpoint.checkpoint( + # vblock, + # decoder_layer, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_value, + # image_hidden_states, + # image_attention_mask, + # output_attentions, + # use_cache, + # no_images, + # idx, + # self.cross_layer_interval, + # self.gated_cross_attn_layers, + # ) + # else: + layer_outputs = vblock( + decoder_layer, + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + output_attentions=output_attentions, + use_cache=use_cache, + no_images=no_images, + layer_idx=idx, + cross_layer_interval=self.cross_layer_interval, + gated_cross_attn_layers=self.gated_cross_attn_layers, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class IdeficsForVisionText2Text(IdeficsPreTrainedModel): + def __init__( + self, + config, + weights, + ): + super().__init__(config) + self.model = IdeficsModel( + config=config, + weights=weights, + ) + + self.lm_head = IdeficsDecoupledTensorParallelLinear( + config=config, + weights=weights, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + 
attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        image_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            image_embeddings=image_embeddings,
+            image_attention_mask=image_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        # if labels is not None:
+        #     # Shift so that tokens < n predict n
+        #     if attention_mask is not None:
+        #         shift_attention_mask = attention_mask[..., 1:]
+        #         shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
+        #         shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
+        #     else:
+        #         shift_logits = logits[..., :-1, :].contiguous()
+ # shift_labels = labels[..., 1:].contiguous() + # # Flatten the tokens + # loss_fct = CrossEntropyLoss() + # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + # if not return_dict: + # output = (logits,) + outputs[1:] + # return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) + unwanted_kwargs = ["token_type_ids"] + for kwarg in unwanted_kwargs: + inputs.pop(kwarg, None) + return inputs + + @staticmethod + def _expand_inputs_for_generation( + *args, + **model_kwargs, + ): + return expand_inputs_for_generation(*args, **model_kwargs) + + @staticmethod + def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past \ No newline at end of file diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py new file mode 100644 index 00000000..c946ee7b --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -0,0 +1,246 @@ +# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. +# +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" + +Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially +time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! 
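Shape-wise, the resampling idea described here reduces to cross-attention from a fixed set of learned latent queries into the variable-length context. A back-of-the-envelope sketch with stock `nn.MultiheadAttention` and made-up sizes (the module below uses custom tensor-parallel layers, plus the Flamingo detail of concatenating the latents into the keys/values):

```python
import torch
import torch.nn as nn

bsz, seq, n_latents, dim = 2, 257, 64, 32      # e.g. a ViT patch sequence -> 64 latents
latents = nn.Parameter(torch.randn(n_latents, dim))
attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

context = torch.randn(bsz, seq, dim)
queries = latents.expand(bsz, -1, -1)
keys_values = torch.cat([context, queries], dim=1)  # latents attend to context + themselves
resampled, _ = attn(queries, keys_values, keys_values)
print(resampled.shape)  # torch.Size([2, 64, 32]), independent of the input length
```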
Note +that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to +prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that +to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. + +References: + - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model + - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch + +""" +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from transformers import IdeficsConfig +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +EPS=1e-5 + +class IdeficsPerceiverResampler(nn.Module): + def __init__( + self, + prefix, + config: IdeficsConfig, + embed_dim: int, + depth: int, + n_heads: int, + head_dim: int, + n_latents: int, + weights, + ) -> None: + """ + Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or + MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then + returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed + to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. + Could be e.g., VIT embed_dim, ResNet pool dim, and so on. + + Args: + config (`IdeficsConfig`): config object + embed_dim (`int`): The size of each embedding vector + depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). + head_dim (`int`): Dimensionality of each head projection in the Transformer block. + n_latents (`int`): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + + """ + super().__init__() + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents + self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver + + # Create Latents for Perceiver + self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents")) + + self.intermediate_dim = ( + self.embed_dim * 4 + if not hasattr(config.vision_config, "embed_dim") + else config.vision_config.embed_dim * 4 + ) + # Create Transformer Blocks + self.blocks = nn.ModuleList( + [ + nn.ModuleList( + [ + IdeficsPerceiverAttention( + prefix=f"{prefix}.blocks.{layer_id}.0", + config=config, + embed_dim=self.embed_dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + qk_layer_norms=self.qk_layer_norms, + weights=weights, + ), + IdeficsMLP( + prefix=f"{prefix}.blocks.{layer_id}.1", + intermediate_size=self.intermediate_dim, + config=config, + weights=weights + ), + ] + ) + for layer_id in range(depth) + ] + ) + self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS) + + def forward(self, context: torch.Tensor) -> torch.Tensor: + """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" + # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) + latents = self.latents.repeat(context.shape[0], 1, 1) + + # Feed through Perceiver Attention blocks... 
+ for attn, ff in self.blocks: + latents = attn(context, latents) + latents + latents = ff(latents) + latents + + return self.layer_norm(latents) + + +class IdeficsPerceiverAttention(nn.Module): + def __init__(self, + prefix, + config, + embed_dim: int, + n_heads: int, + head_dim: int, + qk_layer_norms: bool, + weights + ) -> None: + """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" + super().__init__() + self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim + self.qk_layer_norms = qk_layer_norms + # Normalization & Scaling + self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS) + self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS) + if self.qk_layer_norms: + self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS) + self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS) + + self.qk_scale = self.head_dim**-0.5 + + process_group = weights.process_group + if n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + + # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). + self.q_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False + ) + self.k_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False + ) + self.v_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False + ) + + self.output_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False + ) + + def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! + + Args: + context (`torch.Tensor`): + Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. + latents (`torch.Tensor`): + Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. + + Returns: + `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross + from context. + """ + context = self.context_layer_norm(context) + latents = self.latents_layer_norm(latents) + batch_size, seq_length, embed_dim = context.shape[:3] + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! + # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + q = self.q_proj(latents) + k = self.k_proj(torch.cat([context, latents], dim=-2)) + v = self.v_proj(torch.cat([context, latents], dim=-2)) + + # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) + # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] + # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) + q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)] + + if self.qk_layer_norms: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + + scores = torch.einsum("... i d, ... 
j d -> ... i j", q * self.qk_scale, k) + stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach()) + attn = stabilized_scores.softmax(dim=-1) + + # Attend & project back to output... + resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v) + # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads) + return self.output_proj(resampled.transpose(1, 2).flatten(-2)) + + +class IdeficsMLP(nn.Module): + def __init__(self, + prefix, + intermediate_size, + config: IdeficsConfig, + weights, + ): + """Simple MLP block with intermediate_size and embedding size""" + super().__init__() + self.embed_dim = config.vision_config.embed_dim + self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS) + self.fc = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.fc", weights=weights, bias=False, + ) + self.act = nn.ReLU() + self.c_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False, + ) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.ln(hidden_states) + hidden_states = self.fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py new file mode 100644 index 00000000..e42f5c1f --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -0,0 +1,474 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.utils import ( + ModelOutput, + logging, +) +from transformers.models.idefics.configuration_idefics import IdeficsVisionConfig +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +@dataclass +class IdeficsVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics +class IdeficsVisionEmbeddings(nn.Module): + def __init__(self, prefix, config: IdeficsVisionConfig, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding")) + + self.patch_embedding = nn.Conv2d.load_no_bias( + prefix=f"{prefix}.patch_embedding", + weights=weights, + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + # self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.position_embedding = TensorParallelEmbedding( + prefix="model.vision_model.embeddings.position_embedding", weights=weights + ) + # self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + self.position_ids = weights.get_tensor(f"{prefix}.position_ids") + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision +class IdeficsVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: 
{self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        process_group = weights.process_group
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+            )
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+        )
+        self.q_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+        )
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
+class IdeficsVisionMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.fc1", weights=weights, bias=True
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.fc2", weights=weights, bias=True
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
+class IdeficsVisionEncoderLayer(nn.Module):
+    def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+        self.layer_norm1 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+        )
+        self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.layer_norm2 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
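+        Returns:
+            `Tuple[torch.FloatTensor]`: the transformed hidden states, plus the attention weights when
+            `output_attentions=True`.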
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision +class IdeficsVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`IdeficsVisionEncoderLayer`]. + + Args: + config: IdeficsVisionConfig + """ + + def __init__(self, prefix, config: IdeficsVisionConfig, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + # self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # if self.gradient_checkpointing and self.training: + + # def create_custom_forward(module): + # def custom_forward(*inputs): + # return module(*inputs, output_attentions) + + # return custom_forward + + # layer_outputs = torch.utils.checkpoint.checkpoint( + # create_custom_forward(encoder_layer), + # hidden_states, + # attention_mask, + # causal_attention_mask, + # ) + # else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer +class IdeficsVisionTransformer(nn.Module): + def __init__(self, prefix, config: IdeficsVisionConfig, weights): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights) + self.pre_layrnorm = nn.LayerNorm.load( + prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps + ) + self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps + ) + + # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + 
encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py new file mode 100644 index 00000000..b30c5f80 --- /dev/null +++ b/server/text_generation_server/models/idefics.py @@ -0,0 +1,112 @@ +import torch +import torch.distributed + +from typing import List, Optional, Tuple + +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoProcessor, +) + +from text_generation_server.models import IdeficsCausalLM +from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + + +class IDEFICSSharded(IdeficsCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 + self.device, self.dtype = device, dtype + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.vision_config.quantize = quantize + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + ) + + model = IdeficsForVisionText2Text(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(IdeficsCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + pixel_values: Optional = None, + image_attention_mask: Optional = None, + past_key_values: Optional = None, + ) -> Tuple[ + torch.Tensor, + List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + ]: + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + return ( + outputs.logits, + outputs.past_key_values, + ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py new file mode 100644 index 00000000..74afc0a4 --- /dev/null +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -0,0 +1,837 @@ +import torch +import inspect +import re +from io import BytesIO +import base64 +from PIL import Image +import json + +from dataclasses import dataclass 
+from opentelemetry import trace
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin, IdeficsForVisionText2Text
+from typing import Optional, Tuple, List, Type, Dict
+
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+
+tracer = trace.get_tracer(__name__)
+
+
+# UTILS
+def base64_to_pil(encoded_image):
+    decoded_image = base64.b64decode(encoded_image)
+    pil_image = Image.open(BytesIO(decoded_image))
+    return pil_image
+
+def im_markdown_to_pil(im_markdown_str):
+    # Matches inline images of the form <img src="data:image/png;base64,..."/>
+    pattern = r'<img src="data:image/png;base64,([^"]+)"/>'
+    match = re.search(pattern, im_markdown_str)
+    img_b64_str = match.group(1)
+    return base64_to_pil(img_b64_str)
+
+def split_str_on_im_markdown(string_with_potential_im_markdown):
+    """
+    Extract from a string (typically the user prompt string) the potential images saved as a base64 representation
+    inside a markdown.
+    """
+    pattern = r'<img src="data:image/png;base64,([^"]+)"/>'
+    parts = re.split(pattern, string_with_potential_im_markdown)
+    result = []
+    for i, part in enumerate(parts):
+        if i % 2 == 0:
+            result.append(part)
+        else:
+            img_tag = f'<img src="data:image/png;base64,{part}"/>'
+            result.append(img_tag)
+    return result
+
+
+@dataclass
+class IdeficsCausalLMBatch(Batch):
+    batch_id: int
+    requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]
+
+    # Decoder values
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    position_ids: torch.Tensor
+    pixel_values: Optional[torch.Tensor]
+    image_attention_mask: Optional[torch.Tensor]
+    past_key_values: Optional[List[Tuple]]
+
+    # All tokens
+    all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
+
+    # Generation helpers
+    next_token_choosers: List[NextTokenChooser]
+    stopping_criterias: List[StoppingCriteria]
+
+    # Metadata used for padding
+    max_input_length: int
+    padding_right_offset: int
+
+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
+    # Past metadata
+    keys_head_dim_last: bool = True
+
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
+            id=self.batch_id,
+            request_ids=[r.id for r in self.requests],
+            size=len(self),
+            max_tokens=self.max_tokens,
+        )
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        processor: ProcessorMixin,  # Hack
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "IdeficsCausalLMBatch":
+        inputs = []
+        next_token_choosers = []
+        stopping_criterias = []
+        prefix_offsets = []
+        read_offsets = []
+        requests_idx_mapping = {}
+
+        # Parse batch
+        max_truncation = 0
+        padding_right_offset = 0
+        max_decode_tokens = 0
+        for i, r in enumerate(pb.requests):
+            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
+            requests_idx_mapping[r.id] = i
+            inputs.append(r.inputs)
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)  # TODO: understand that
+            max_decode_tokens += stopping_criteria.max_new_tokens  # TODO: I think this is just the maximum number of tokens to generate across the WHOLE batch
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
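+        # padding_right_offset reserves enough right-hand padding for whichever
+        # request in the batch may still generate the most new tokens.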
+
+        prompts = []
+        for inp in inputs:
+            # Each input is encoded into a list, where each element of this input list is either a string or a URL
+            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
+            if isinstance(inp, str):
+                prompts.append([inp])
+            elif isinstance(inp, list):
+                if not all(isinstance(item, str) for item in inp):
+                    raise ValueError("All elements in the list must be strings (text string or image URL)")
+                # The list has already been validated above, so keep it as the multimodal prompt
+                prompts.append(inp)
+            else:
+                raise ValueError("Unsupported type of input")
+            # I initially wanted to send the images in string base64 but they are too big to send in a consistent way...
+            # So resorting to uploading the image to a server and pulling them back
+            # splitted_inp = split_str_on_im_markdown(inp)
+            # prompts.append(
+            #     [
+            #         im_markdown_to_pil(s) if s.startswith('<img src="data:image/png;base64,') else s
+
+    @tracer.start_as_current_span("filter")
+    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+        from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
+        # It deletes requests from the batch. For instance when client lost connection
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(request_ids) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        requests = []
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        total_remaining_decode_tokens = 0
+        new_padding_right_offset = 0
+
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+            keep_indices.append(idx)
+
+            requests.append(self.requests[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        input_ids = self.input_ids[keep_indices]
+        position_ids = self.position_ids[keep_indices]
+        self.attention_mask = self.attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
+        ]
+        # Do the same for pixel_values and image_attention_mask
+        pixel_values = self.pixel_values[keep_indices]
+        self.image_attention_mask = self.image_attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + max_input_length) : (
+                self.image_attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
+            :
+        ]
+
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+        # Update tensors in-place to allow incremental garbage collection
+        past_kv_length = max_input_length - 1
+        for layer in
self.past_key_values: + past_keys, past_values = layer + if len(past_keys.shape) == 3: + # Force past to be of dim [self_size, num_heads, ...] for easy indexing + past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) + past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) + if self.keys_head_dim_last: + layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] + else: + layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] + del past_keys + layer[1] = past_values[keep_indices, :, -past_kv_length:, :] + del past_values + + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.pixel_values = pixel_values + self.position_ids = position_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch": + from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py") + # It adds new requests to the batch + # Used for padding + total_batch_size = 0 + max_input_length = 0 + max_num_images = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + max_num_images = max(max_num_images, batch.pixel_values.size(1)) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + max_tokens = 0 + + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + pixel_values = None + image_attention_mask = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + attention_mask = 
batch.attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset), + ) + + curr_batch_max_num_images = batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224)) + pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values + + if image_attention_mask is None: + image_attention_mask = batch.image_attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset, max_num_images) + ) + + # We need to slice the attention mask to remove padding from previous steps + # and to remove unused allocated space + left_offset = max_input_length - batch.max_input_length + batch_left_offset = ( + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset + ) + attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}") + from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}") + image_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images + ] = batch.image_attention_mask[ + :, + batch_left_offset : - batch.padding_right_offset, + : + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((total_batch_size, 1)) + position_ids[start_index:end_index] = batch.position_ids + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if type(batch.past_key_values[0]) == tuple: + batch.past_key_values = [ + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(len(batch), -1, *t.shape[-2:]) + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_input_length - 1, + head_dim, + ) + + if batches[0].keys_head_dim_last: + padded_past_keys_shape = padded_past_values_shape + else: + # seq_length is last for BLOOM + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_input_length - 1, + ) + + # Iterate over attention layers + # Concatenate past key values layer by layer to allow incremental garbage collection + for j in range(len(first_past_kvs)): + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous 
batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[ + start_index:end_index, :, -past_seq_len:, : + ] = past_keys[:, :, -past_seq_len:, :] + else: + # BLOOM case + padded_past_keys[ + start_index:end_index, :, :, -past_seq_len: + ] = past_keys[:, :, :, -past_seq_len:] + del past_keys + + start_index = end_index + + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[ + start_index:end_index, :, -past_seq_len:, : + ] = past_values[:, :, -past_seq_len:, :] + del past_values + + # Update values + start_index = end_index + + past_key_values.append([padded_past_keys, padded_past_values]) + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + past_key_values=past_key_values, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + ) + + def __len__(self): + return len(self.requests) + + +class IdeficsCausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + model = IdeficsForVisionText2Text.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map="auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": ""}) + + super(IdeficsCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def 
batch_type(self) -> Type[IdeficsCausalLMBatch]: + return IdeficsCausalLMBatch + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + pixel_values, + image_attention_mask, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_attention_mask": image_attention_mask, + "past_key_values": past_key_values, + "use_cache": True, + "return_dict": True, + } + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + outputs = self.model.forward(**kwargs) + return outputs.logits, outputs.past_key_values + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: IdeficsCausalLMBatch + ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]: + from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter") + # slice the attention mask to the correct shape + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + if batch.input_ids.size(1) == 1: + # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), + # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension + # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated + # token need to attend to the encoder hidden states (i.e. the vision encoder) + # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic + image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. 
i have a doubt whether there is +1 hanging around + else: + image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset] + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}") + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}") + + logits, past = self.forward( + input_ids=batch.input_ids, + attention_mask=attention_mask, + position_ids=batch.position_ids, + pixel_values=batch.pixel_values, + image_attention_mask=image_attention_mask, + past_key_values=batch.past_key_values, + ) + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :, 0] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, 
+ prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + next_token_id_squeezed.item() in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + + # Update values + batch.input_ids[i, 0] = next_token_id + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}") + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + return generations, None + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}") + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :] + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + + from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}") + return generations, batch diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 1cedc151..e7ed8875 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -14,6 +14,7 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor +from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -54,9 +55,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device - ) + if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) max_supported_total_tokens = self.model.warmup(batch) return generate_pb2.WarmupResponse( @@ -64,9 +70,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): ) async def Prefill(self, request, context): - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device - ) + if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device + ) + else: + batch = 
self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 9688f6e0..745c1d2e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -51,7 +51,31 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): ln.bias = None return ln +@classmethod +def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + conv2d.weight = nn.Parameter(weight) + conv2d.bias = nn.Parameter(bias) + return conv2d + + +@classmethod +def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + + conv2d.weight = nn.Parameter(weight) + conv2d.bias = None + return conv2d + + +torch.nn.Conv2d.load = load_conv2d +torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
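A minimal sketch of how these patched loaders are meant to be used (illustrative only, not part of the patch: the `DummyWeights` stub and all names and shapes below are invented for the example; in the server the real `Weights` helper provides `get_tensor`, and importing `text_generation_server.utils.layers` is what installs `load` / `load_no_bias` on `torch.nn.Conv2d`):

import torch
import torch.nn as nn
from text_generation_server.utils import layers  # noqa: F401 -- importing applies the Conv2d/LayerNorm monkey-patches

class DummyWeights:
    """Minimal stand-in for text_generation_server.utils.Weights: maps names to tensors."""
    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]

# Invented checkpoint slice: a bias-free 4x4 patch embedding over RGB input.
weights = DummyWeights({"vision.patch_embedding.weight": torch.randn(64, 3, 4, 4)})

patch_embedding = nn.Conv2d.load_no_bias(
    prefix="vision.patch_embedding",
    weights=weights,
    in_channels=3,
    out_channels=64,
    kernel_size=4,
    stride=4,
)
print(patch_embedding(torch.randn(1, 3, 32, 32)).shape)  # -> torch.Size([1, 64, 8, 8])

The same pattern is what `IdeficsVisionEmbeddings` relies on above: the convolution is constructed on the meta device via `init_empty_weights()` and the checkpoint tensor is then attached directly, so no dense weight is ever allocated twice.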