diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 621652e8..518e1a15 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -18,6 +18,8 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
+from text_generation_server.models.idefics import IDEFICSSharded
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -40,6 +42,7 @@ __all__ = [
"OPTSharded",
"T5Sharded",
"get_model",
+ "IDEFICSSharded",
]
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -248,6 +251,14 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
+ elif model_type == "idefics":
+ return IDEFICSSharded(
+ model_id,
+ revision,
+ quantize=quantize,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
if sharded:
raise ValueError("sharded is not supported for AutoModel")
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
new file mode 100644
index 00000000..42c079d5
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -0,0 +1,1318 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Idefics model."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PretrainedConfig
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers import IdeficsConfig
+from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer
+from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler
+from text_generation_server.utils.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ TensorParallelHead,
+ PositionRotaryEmbedding,
+ FastLinear,
+)
+
+
+logger = logging.get_logger(__name__)
+
+# _CONFIG_FOR_DOC = "IdeficsConfig"
+
+# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [
+# "HuggingFaceM4/idefics-9b",
+# "HuggingFaceM4/idefics-80b",
+# # See all Idefics models at https://huggingface.co/models?filter=idefics
+# ]
+
+
+def expand_inputs_for_generation(
+ input_ids,
+ expand_size=1,
+ is_encoder_decoder=False,
+ attention_mask=None,
+ encoder_outputs=None,
+ **model_kwargs,
+):
+ expanded_return_idx = (
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+ )
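+    # Illustrative example: with a batch of 2 and expand_size=3, expanded_return_idx is
+    # tensor([0, 0, 0, 1, 1, 1]), so every row of input_ids (and of the masks below) is
+    # repeated expand_size times consecutively.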
+ input_ids = input_ids.index_select(0, expanded_return_idx)
+
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
+
+ if attention_mask is not None:
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
+ model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
+ 0, expanded_return_idx
+ )
+ model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
+
+ if is_encoder_decoder:
+ if encoder_outputs is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+ 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+ )
+ model_kwargs["encoder_outputs"] = encoder_outputs
+ return input_ids, model_kwargs
+
+
+def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
+ # must have this key set to at least None
+ model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
+
+ # update past
+ if "past_key_values" in outputs:
+ model_kwargs["past"] = outputs.past_key_values
+ elif "mems" in outputs:
+ model_kwargs["past"] = outputs.mems
+ elif "past_buckets_states" in outputs:
+ model_kwargs["past"] = outputs.past_buckets_states
+ else:
+ model_kwargs["past"] = None
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ # update attention masks
+ if not is_encoder_decoder:
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+ if "image_attention_mask" in model_kwargs:
+ image_attention_mask = model_kwargs["image_attention_mask"]
+ last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+ model_kwargs["image_attention_mask"] = last_mask
+
+ return model_kwargs
+
+
+def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids", None)
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ pixel_values = kwargs.get("pixel_values", None)
+ image_attention_mask = kwargs.get("image_attention_mask", None)
+ # if pixel_values is None or image_attention_mask is None:
+ # raise ValueError("pixel values and image attention mask cannot be None")
+
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ "pixel_values": pixel_values,
+ "image_attention_mask": image_attention_mask,
+ }
+
+
+def freeze_model(model, module_exceptions=[]):
+ mapping = {
+ "LayerNorm": nn.LayerNorm,
+ "Linear": nn.Linear,
+ "Embedding": nn.Embedding,
+ }
+ module_exceptions_mapped = [mapping[m] for m in module_exceptions]
+ for module in model.modules():
+ if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
+            module.requires_grad_(True)  # Explicitly setting it to True to avoid any mistakes
+ else:
+ module.requires_grad_(False)
+ return model
+
+
+class IdeficsDecoupledPartialTPEmbedding(nn.Module):
+ def __init__(
+ self,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_embeddings = config.vocab_size
+ self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
+        self.additional_weight = nn.Parameter(weights.get_tensor("model.embed_tokens.additional_embedding.weight"))
+
+ def forward(self, input_ids):
+ # Clone so that we don't modify the original input_ids later on
+ input_ids = input_ids.clone()
+ additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
+ additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight)
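+        # Illustrative example: if vocab_size were 32000, a token id of 32001 would be looked up
+        # at index 1 of `additional_weight`, while ids below 32000 go through the sharded
+        # TensorParallelEmbedding.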
+
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+ input_ids[additional_vocab_indices] = 0
+ full_vector = self.weight(input_ids)
+
+ # overwrite the records with high indices
+ full_vector[additional_vocab_indices] = additional_embeddings
+
+ return full_vector
+
+
+class IdeficsDecoupledTensorParallelLinear(nn.Module):
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
+ """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
+ then it will create `out_additional_features * in_features` additional parameters that are always trained. If
+ `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
+ """
+
+ def __init__(
+ self,
+ config,
+ weights,
+ ) -> None:
+ super().__init__()
+ self.fc = TensorParallelHead.load(
+ config=config, prefix="lm_head", weights=weights
+ )
+ self.additional_fc = FastLinear.load(
+ config=config, prefix="lm_head.additional_fc", weights=weights, bias=False,
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.fc(input)
+ additional_features = self.additional_fc(input)
+ output = torch.cat((output, additional_features), -1)
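+        # The last dimension is the regular vocab size plus the additional features handled by
+        # `additional_fc` (e.g. the extra image tokens), concatenated in that order.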
+
+ return output
+
+    def extra_repr(self) -> str:
+        """Overwriting `nn.Linear.extra_repr`; this tensor-parallel variant does not track
+        `in_features`/`out_features` directly, and the `fc`/`additional_fc` children already
+        appear in the module repr."""
+        return ""
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
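+    # For tgt_len == 3 each row ends up as [0, min, min], [0, 0, min], [0, 0, 0] (with `min`
+    # being the most negative value of `dtype`), i.e. a position can only attend to itself and
+    # to earlier positions.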
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
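+    # Positions where mask == 1 become 0.0 (attend) and positions where mask == 0 become the
+    # dtype minimum (effectively -inf), matching the additive-mask convention of _make_causal_mask.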
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# this was adapted from LlamaRMSNorm
+class IdeficsRMSNorm(nn.Module):
+ def __init__(self, prefix, weights, eps=1e-6):
+ """
+ IdeficsRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+
+ weight = weights.get_tensor(f"{prefix}.weight")
+ self.weight = nn.Parameter(weight)
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
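+        # RMSNorm: x * rsqrt(mean(x^2) + eps) scaled by a learned weight; unlike LayerNorm there
+        # is no mean subtraction and no bias term.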
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+# this was adapted from LlamaRotaryEmbedding
+class IdeficsEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
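+    # Standard RoPE: with rotate_half([x1, x2]) == [-x2, x1], the first and second halves of the
+    # head dimension rotate as [x1*cos - x2*sin, x2*cos + x1*sin], so relative offsets end up
+    # encoded in the q·k dot products.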
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# this was adapted from LlamaMLP
+class IdeficsMLP(nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix,
+ weights,
+ ):
+ super().__init__()
+ self.gate_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
+ )
+ self.up_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.up_proj", weights=weights, bias=False,
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
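+        # SwiGLU-style gating as in Llama: down_proj(act(gate_proj(x)) * up_proj(x)).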
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# this was adapted from LlamaAttention
+class IdeficsAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config,
+ prefix,
+ weights,
+ qk_layer_norms: bool = False,
+ is_cross_attention: bool = False,
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.dropout = config.dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.is_cross_attention = is_cross_attention
+
+ # if not hasattr(nn.functional, "scaled_dot_product_attention"):
+ # raise ValueError("this model requires pytorch 2.0 or higher")
+
+ process_group = weights.process_group
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})."
+            )
+        # The q/k/v projections below are column-sharded, so each rank only holds its local heads.
+        self.num_heads //= weights.process_group.size()
+
+ if self.is_cross_attention:
+ # kv_input_dim = (
+ # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
+ # )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+ )
+ else:
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
+ )
+ # self.rotary_emb = PositionRotaryEmbedding.load(
+ # prefix=f"{prefix}.rotary_emb", weights=weights
+ # )
+        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device)  # TODO: verify; kept instead of PositionRotaryEmbedding, which looks specific to the flash-attention path
+
+ self.qk_layer_norms = qk_layer_norms
+ if self.qk_layer_norms:
+ self.q_layer_norm = IdeficsRMSNorm(
+ prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.k_layer_norm = IdeficsRMSNorm(
+ prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ if not is_cross_attention:
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ else:
+ _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
+ key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = (
+ self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ if not is_cross_attention:
+ cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # cos, sin = self.rotary_emb.get_cos_sin(
+ # position_ids=torch.arange(),
+ # max_s=max(kv_seq_len, q_len),
+ # dtype=hidden_states.dtype,
+ # )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ if self.qk_layer_norms:
+ query_states = self.q_layer_norm(query_states)
+ key_states = self.k_layer_norm(key_states)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_output = nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout,
+ )
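+        # The additive attn_mask of shape (bsz, 1, q_len, kv_seq_len) broadcasts over the
+        # (sharded) head dimension inside scaled_dot_product_attention.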
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, -1)  # local (per-shard) hidden size
+
+ attn_output = self.o_proj(attn_output)
+
+ attn_weights = None
+ if output_attentions:
+ logger.warning_once(
+ "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
+ )
+
+ return attn_output, attn_weights, past_key_value
+
+
+# this was adapted from LlamaDecoderLayer
+class IdeficsDecoderLayer(nn.Module):
+ def __init__(self, layer_id: int, config: IdeficsConfig, weights):
+ super().__init__()
+ self.process_group = weights.process_group
+ self.hidden_size = config.hidden_size
+ prefix = f"model.layers.{layer_id}"
+ self.self_attn = IdeficsAttention(
+ config=config,
+ prefix=f"{prefix}.self_attn",
+ weights=weights,
+ qk_layer_norms=False,
+ is_cross_attention=False,
+ )
+ self.mlp = IdeficsMLP(
+ config=config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+ self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
+ self.dropout = config.dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class IdeficsGatedCrossAttentionLayer(nn.Module):
+ def __init__(self, layer_id, config: IdeficsConfig, weights):
+ super().__init__()
+ self.process_group = weights.process_group
+ self.hidden_size = config.hidden_size
+ prefix = f"model.gated_cross_attn_layers.{layer_id}"
+ self.cross_attn = IdeficsAttention(
+ config=config,
+ prefix=f"{prefix}.cross_attn",
+ weights=weights,
+ qk_layer_norms=True,
+ is_cross_attention=True,
+ )
+ self.mlp = IdeficsMLP(
+ config=config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+ self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
+        self.dropout = config.dropout
+
+ self.act_cross_attn = nn.Tanh()
+ self.act_dense = nn.Tanh()
+
+ self.alpha_cross_attn = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_cross_attn"))
+ self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))
+
+ if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
+ raise ValueError("Alpha parameters not initialized correctly!")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_hidden_states: Optional[torch.Tensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ no_images: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored
+ """
+ if image_hidden_states is None:
+ raise ValueError(
+ "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
+ " conditioned on."
+ )
+
+ if past_key_value is not None:
+ raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+        # Cross Attention
+ hidden_states, self_attn_weights, present_key_value = self.cross_attn(
+ hidden_states=hidden_states,
+ key_value_states=image_hidden_states,
+ attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ # when there are no images the model is used in pure language mode
+ gate = 0 if no_images else 1
+ hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
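+        # Flamingo-style gating: the cross-attention branch is scaled by tanh(alpha_cross_attn)
+        # and skipped entirely (gate == 0) when the batch carries no images.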
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`IdeficsConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+# @add_start_docstrings(
+# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+# LLAMA_START_DOCSTRING,
+# )
+class IdeficsPreTrainedModel(PreTrainedModel):
+ config_class = IdeficsConfig
+ # base_model_prefix = "model"
+ # supports_gradient_checkpointing = True
+ # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
+
+ # def _init_weights(self, module):
+ # # important: this ported version of Idefics isn't meant for training from scratch - only
+ # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
+ # # base should be used for training from scratch and it contains the correct code.
+ # std = self.config.initializer_range
+ # if isinstance(module, nn.Linear):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.bias is not None:
+ # module.bias.data.zero_()
+ # elif isinstance(module, nn.Embedding):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.padding_idx is not None:
+ # module.weight.data[module.padding_idx].zero_()
+
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if isinstance(module, IdeficsModel):
+ # module.gradient_checkpointing = value
+
+
+# LLAMA_INPUTS_DOCSTRING = r"""
+# Args:
+# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+# it.
+
+# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+# [`PreTrainedTokenizer.__call__`] for details.
+
+# [What are input IDs?](../glossary#input-ids)
+# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+# - 1 for tokens that are **not masked**,
+# - 0 for tokens that are **masked**.
+
+# [What are attention masks?](../glossary#attention-mask)
+
+# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+# [`PreTrainedTokenizer.__call__`] for details.
+
+# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+# `past_key_values`).
+
+# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+# information on the default strategy.
+
+# - 1 indicates the head is **not masked**,
+# - 0 indicates the head is **masked**.
+# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+# `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+# model's internal embedding lookup matrix.
+# use_cache (`bool`, *optional*):
+# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+# `past_key_values`).
+# output_attentions (`bool`, *optional*):
+# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+# tensors for more detail.
+# output_hidden_states (`bool`, *optional*):
+# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+# more detail.
+# return_dict (`bool`, *optional*):
+# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+# """
+
+
+# @add_start_docstrings(
+# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+# LLAMA_START_DOCSTRING,
+# )
+class IdeficsModel(IdeficsPreTrainedModel):
+ # """
+ # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]
+
+ # Args:
+ # config: IdeficsConfig
+ # """
+
+ def __init__(self, config: IdeficsConfig, weights):
+ super().__init__(config)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(
+ config=config,
+ weights=weights,
+ )
+
+ self.image_size = config.vision_config.image_size
+ self.vision_config = config.vision_config
+ self.vision_model = IdeficsVisionTransformer(
+ prefix="model.vision_model",
+ config=config.vision_config,
+ weights=weights,
+ )
+
+ # Perceiver Resampler
+ if config.use_resampler:
+ perceiver_config = config.perceiver_config
+ self.perceiver_resampler = IdeficsPerceiverResampler(
+ prefix=f"model.perceiver_resampler",
+ config=config,
+ embed_dim=config.vision_config.embed_dim,
+ depth=perceiver_config.resampler_depth,
+ n_heads=perceiver_config.resampler_n_heads,
+ head_dim=perceiver_config.resampler_head_dim,
+ n_latents=perceiver_config.resampler_n_latents,
+ weights=weights,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ IdeficsDecoderLayer(layer_id, config, weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.cross_layer_interval = config.cross_layer_interval
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
+ self.gated_cross_attn_layers = nn.ModuleList(
+ [
+ IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
+ for layer_id in range(num_cross_layers)]
+ )
+ # self.gradient_checkpointing = False
+
+        self.norm = IdeficsRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps)
+
+ # self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ # self.post_init()
+
+ # self.freeze_relevant_params(config)
+
+ # def freeze_relevant_params(self, config=None):
+ # if config is None:
+ # config = self.config
+
+ # if config.freeze_text_layers:
+ # self.freeze_text_layers(config.freeze_text_module_exceptions)
+
+ # if config.freeze_vision_layers:
+ # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
+
+ # def freeze_text_layers(self, module_exceptions=[]):
+ # for module in [self.layers, self.norm]:
+ # freeze_model(module, module_exceptions=module_exceptions)
+
+ # def freeze_vision_layers(self, module_exceptions=[]):
+ # freeze_model(self.vision_model, module_exceptions=module_exceptions)
+
+ # def get_input_embeddings(self):
+ # return self.embed_tokens
+
+ # def set_input_embeddings(self, value):
+ # self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ elif position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ no_images = False
+ if pixel_values is None and image_embeddings is None:
+ raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
+
+ elif pixel_values is not None and image_embeddings is not None:
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
+
+ elif pixel_values is not None:
+ no_images = len(torch.nonzero(pixel_values)) == 0
+ pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
+ batch_size, num_images = pixel_values.shape[:2]
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
+
+ elif image_embeddings is not None:
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
+
+ if self.config.use_resampler:
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+ # # Hack to use the model in full language modeling mode
+ # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
+ # Make image_attention_mask compatible with hidden states
+ text_seq_len = image_attention_mask.size(1)
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
+
+ if image_hidden_states is not None:
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
+ image_hidden_shape = (image_batch_size, image_sequence_length)
+ if image_attention_mask is None:
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
+ image_attention_mask = self.invert_attention_mask(image_attention_mask)
+ else:
+ image_attention_mask = None
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ # if self.gradient_checkpointing and self.training:
+ # if use_cache:
+ # logger.warning_once(
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ # )
+ # use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ def vblock(
+ main_block,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ image_hidden_states,
+ image_attention_mask,
+ output_attentions,
+ use_cache,
+ no_images,
+ layer_idx,
+ cross_layer_interval,
+ gated_cross_attn_layers,
+ ):
+ # TODO(ls): Add cross attention values to respective lists
+ if layer_idx % cross_layer_interval == 0:
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
+ outputs = xblock(
+ hidden_states,
+ attention_mask=attention_mask,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ past_key_value=None, # not implemented
+ no_images=no_images,
+ )
+ hidden_states = outputs[0]
+
+ layer_outputs = main_block(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ return layer_outputs
+
+ # if self.gradient_checkpointing and self.training:
+ # past_key_value = None
+ # if use_cache:
+ # logger.warning_once(
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ # )
+ # use_cache = False
+
+ # layer_outputs = torch.utils.checkpoint.checkpoint(
+ # vblock,
+ # decoder_layer,
+ # hidden_states,
+ # attention_mask,
+ # position_ids,
+ # past_key_value,
+ # image_hidden_states,
+ # image_attention_mask,
+ # output_attentions,
+ # use_cache,
+ # no_images,
+ # idx,
+ # self.cross_layer_interval,
+ # self.gated_cross_attn_layers,
+ # )
+ # else:
+ layer_outputs = vblock(
+ decoder_layer,
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ no_images=no_images,
+ layer_idx=idx,
+ cross_layer_interval=self.cross_layer_interval,
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
+ def __init__(
+ self,
+ config,
+ weights,
+ ):
+ super().__init__(config)
+ self.model = IdeficsModel(
+ config=config,
+ weights=weights,
+ )
+
+ self.lm_head = IdeficsDecoupledTensorParallelLinear(
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ image_embeddings=image_embeddings,
+ image_attention_mask=image_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ # if labels is not None:
+ # # Shift so that tokens < n predict n
+ # if attention_mask is not None:
+ # shift_attention_mask = attention_mask[..., 1:]
+ # shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
+ # shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
+ # else:
+ # shift_logits = logits[..., :-1, :].contiguous()
+ # shift_labels = labels[..., 1:].contiguous()
+ # # Flatten the tokens
+ # loss_fct = CrossEntropyLoss()
+ # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ # if not return_dict:
+ # output = (logits,) + outputs[1:]
+ # return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
+ unwanted_kwargs = ["token_type_ids"]
+ for kwarg in unwanted_kwargs:
+ inputs.pop(kwarg, None)
+ return inputs
+
+ @staticmethod
+ def _expand_inputs_for_generation(
+ *args,
+ **model_kwargs,
+ ):
+ return expand_inputs_for_generation(*args, **model_kwargs)
+
+ @staticmethod
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
\ No newline at end of file
diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
new file mode 100644
index 00000000..c946ee7b
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
@@ -0,0 +1,246 @@
+# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
+#
+# MIT License
+#
+# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+"""
+
+Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
+time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
+that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
+prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
+to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
+
+References:
+ - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
+ - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
+
+"""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from transformers import IdeficsConfig
+from text_generation_server.utils.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+EPS = 1e-5
+
+class IdeficsPerceiverResampler(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config: IdeficsConfig,
+ embed_dim: int,
+ depth: int,
+ n_heads: int,
+ head_dim: int,
+ n_latents: int,
+ weights,
+ ) -> None:
+ """
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
+ returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed
+ to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler.
+ Could be e.g., VIT embed_dim, ResNet pool dim, and so on.
+
+ Args:
+ config (`IdeficsConfig`): config object
+ embed_dim (`int`): The size of each embedding vector
+ depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+ n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
+ head_dim (`int`): Dimensionality of each head projection in the Transformer block.
+ n_latents (`int`):
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+
+ """
+ super().__init__()
+ self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
+ self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
+
+ # Create Latents for Perceiver
+ self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents"))
+
+ self.intermediate_dim = (
+ self.embed_dim * 4
+ if not hasattr(config.vision_config, "embed_dim")
+ else config.vision_config.embed_dim * 4
+ )
+ # Create Transformer Blocks
+ self.blocks = nn.ModuleList(
+ [
+ nn.ModuleList(
+ [
+ IdeficsPerceiverAttention(
+ prefix=f"{prefix}.blocks.{layer_id}.0",
+ config=config,
+ embed_dim=self.embed_dim,
+ n_heads=self.n_heads,
+ head_dim=self.head_dim,
+ qk_layer_norms=self.qk_layer_norms,
+ weights=weights,
+ ),
+ IdeficsMLP(
+ prefix=f"{prefix}.blocks.{layer_id}.1",
+ intermediate_size=self.intermediate_dim,
+ config=config,
+ weights=weights
+ ),
+ ]
+ )
+ for layer_id in range(depth)
+ ]
+ )
+ self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS)
+
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
+ """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
+ # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
+ latents = self.latents.repeat(context.shape[0], 1, 1)
+
+ # Feed through Perceiver Attention blocks...
+ for attn, ff in self.blocks:
+ latents = attn(context, latents) + latents
+ latents = ff(latents) + latents
+
+ return self.layer_norm(latents)
+
+
+class IdeficsPerceiverAttention(nn.Module):
+ def __init__(self,
+ prefix,
+ config,
+ embed_dim: int,
+ n_heads: int,
+ head_dim: int,
+ qk_layer_norms: bool,
+ weights
+ ) -> None:
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
+ super().__init__()
+ self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
+ self.qk_layer_norms = qk_layer_norms
+ # Normalization & Scaling
+ self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS)
+ self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS)
+ if self.qk_layer_norms:
+ self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS)
+ self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS)
+
+ self.qk_scale = self.head_dim**-0.5
+
+ process_group = weights.process_group
+        if n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+            )
+        # The Q/K/V projections below are column-parallel, so each shard only holds its share of the heads
+        self.n_heads //= weights.process_group.size()
+
+ # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
+ self.q_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+ )
+
+ self.output_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False
+ )
+
+ def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
+ """
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
+
+ Args:
+ context (`torch.Tensor`):
+ Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
+ latents (`torch.Tensor`):
+ Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
+
+ Returns:
+ `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
+ from context.
+ """
+ context = self.context_layer_norm(context)
+ latents = self.latents_layer_norm(latents)
+ batch_size, seq_length, embed_dim = context.shape[:3]
+
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
+ q = self.q_proj(latents)
+ k = self.k_proj(torch.cat([context, latents], dim=-2))
+ v = self.v_proj(torch.cat([context, latents], dim=-2))
+
+ # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
+ # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
+ # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
+ q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)]
+
+ if self.qk_layer_norms:
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+
+ scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
+ stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
+ attn = stabilized_scores.softmax(dim=-1)
+
+ # Attend & project back to output...
+ resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
+ # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
+ return self.output_proj(resampled.transpose(1, 2).flatten(-2))
+
+
+class IdeficsMLP(nn.Module):
+ def __init__(self,
+ prefix,
+ intermediate_size,
+ config: IdeficsConfig,
+ weights,
+ ):
+ """Simple MLP block with intermediate_size and embedding size"""
+ super().__init__()
+ self.embed_dim = config.vision_config.embed_dim
+ self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
+ self.fc = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.fc", weights=weights, bias=False,
+ )
+ self.act = nn.ReLU()
+ self.c_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False,
+ )
+
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.ln(hidden_states)
+ hidden_states = self.fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+
+ return hidden_states
diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py
new file mode 100644
index 00000000..e42f5c1f
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py
@@ -0,0 +1,474 @@
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.utils import (
+ ModelOutput,
+ logging,
+)
+from transformers.models.idefics.configuration_idefics import IdeficsVisionConfig
+from text_generation_server.utils.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+)
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class IdeficsVisionModelOutput(ModelOutput):
+ """
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+
+ Args:
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
+class IdeficsVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding"))
+
+ self.patch_embedding = nn.Conv2d.load_no_bias(
+ prefix=f"{prefix}.patch_embedding",
+ weights=weights,
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ # self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.position_embedding = TensorParallelEmbedding(
+ prefix="model.vision_model.embeddings.position_embedding", weights=weights
+ )
+ # self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+ self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
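+        # Patch arithmetic sketch (illustrative numbers, not taken from the config): with image_size=224 and
+        # patch_size=14, the conv below produces a 16x16 grid of patch embeddings, i.e. 256 patches, and prepending
+        # the class embedding gives 257 positions, matching num_positions = (image_size // patch_size) ** 2 + 1.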
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
+class IdeficsVisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ process_group = weights.process_group
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+            )
+        # Projections are sharded across the process group: keep only the local head/embedding dimensions
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+ )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
+class IdeficsVisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.fc1", weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.fc2", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
+class IdeficsVisionEncoderLayer(nn.Module):
+ def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
+class IdeficsVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`IdeficsVisionEncoderLayer`].
+
+ Args:
+ config: IdeficsVisionConfig
+ """
+
+ def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ # self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # if self.gradient_checkpointing and self.training:
+
+ # def create_custom_forward(module):
+ # def custom_forward(*inputs):
+ # return module(*inputs, output_attentions)
+
+ # return custom_forward
+
+ # layer_outputs = torch.utils.checkpoint.checkpoint(
+ # create_custom_forward(encoder_layer),
+ # hidden_states,
+ # attention_mask,
+ # causal_attention_mask,
+ # )
+ # else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
+class IdeficsVisionTransformer(nn.Module):
+ def __init__(self, prefix, config: IdeficsVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights)
+ self.pre_layrnorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+ )
+ self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights)
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps
+ )
+
+ # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
new file mode 100644
index 00000000..b30c5f80
--- /dev/null
+++ b/server/text_generation_server/models/idefics.py
@@ -0,0 +1,112 @@
+import torch
+import torch.distributed
+
+from typing import List, Optional, Tuple
+
+from transformers import (
+ AutoTokenizer,
+ AutoConfig,
+ AutoProcessor,
+)
+
+from text_generation_server.models import IdeficsCausalLM
+from text_generation_server.models.custom_modeling.idefics_modeling import (
+ IdeficsForVisionText2Text,
+)
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+
+
+class IDEFICSSharded(IdeficsCausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
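+        # Sharded loading sketch: every rank builds the same config/tokenizer/processor, then reads only its shard of
+        # the safetensors weights through `Weights` and instantiates the tensor-parallel IdeficsForVisionText2Text.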
+ self.process_group, rank, world_size = initialize_torch_distributed()
+ if torch.cuda.is_available():
+ device = torch.device(f"cuda:{rank}")
+ dtype = torch.float16 if dtype is None else dtype
+ else:
+ device = torch.device("cpu")
+ dtype = torch.float32
+ self.device, self.dtype = device, dtype
+
+ config = AutoConfig.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ )
+ config.quantize = quantize
+ config.vision_config.quantize = quantize
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ self.processor = AutoProcessor.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+
+ torch.distributed.barrier(group=self.process_group)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device=device,
+ dtype=dtype,
+ process_group=self.process_group,
+ )
+
+ model = IdeficsForVisionText2Text(config, weights)
+
+ torch.distributed.barrier(group=self.process_group)
+ super(IdeficsCausalLM, self).__init__(
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ )
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[Tuple]] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
+ ]:
+ # Model Forward
+ outputs = self.model.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_attention_mask=image_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+
+ return (
+ outputs.logits,
+ outputs.past_key_values,
+ )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
new file mode 100644
index 00000000..74afc0a4
--- /dev/null
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -0,0 +1,837 @@
+import torch
+import inspect
+import re
+from io import BytesIO
+import base64
+from PIL import Image
+import json
+
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin, IdeficsForVisionText2Text
+from typing import Optional, Tuple, List, Type, Dict
+
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+ Batch,
+ PrefillTokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+
+tracer = trace.get_tracer(__name__)
+
+
+# UTILS
+def base64_to_pil(encoded_image):
+ decoded_image = base64.b64decode(encoded_image)
+ pil_image = Image.open(BytesIO(decoded_image))
+ return pil_image
+
+def im_markdown_to_pil(im_markdown_str):
+    pattern = r'<img src="data:image/png;base64,([^"]+)" />'
+ match = re.search(pattern, im_markdown_str)
+ img_b64_str = match.group(1)
+ return base64_to_pil(img_b64_str)
+
+def split_str_on_im_markdown(string_with_potential_im_markdown):
+ """
+    Extract from a string (typically the user prompt string) the potential images saved as a base64 representation
+ inside a markdown.
+ """
+    pattern = r'<img src="data:image/png;base64,([^"]+)" />'
+ parts = re.split(pattern, string_with_potential_im_markdown)
+ result = []
+ for i, part in enumerate(parts):
+ if i % 2 == 0:
+ result.append(part)
+ else:
+            img_tag = f'<img src="data:image/png;base64,{part.strip()}" />'
+ result.append(img_tag)
+ return result
+
+
+@dataclass
+class IdeficsCausalLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ requests_idx_mapping: Dict[int, int]
+
+ # Decoder values
+ input_ids: torch.Tensor
+ attention_mask: torch.Tensor
+ position_ids: torch.Tensor
+ pixel_values: Optional[torch.Tensor]
+ image_attention_mask: Optional[torch.Tensor]
+ past_key_values: Optional[List[Tuple]]
+
+ # All tokens
+ all_input_ids: List[torch.Tensor]
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
+
+ # Generation helpers
+ next_token_choosers: List[NextTokenChooser]
+ stopping_criterias: List[StoppingCriteria]
+
+ # Metadata used for padding
+ max_input_length: int
+ padding_right_offset: int
+
+ # Maximum number of tokens this batch will grow to
+ max_tokens: int
+
+ # Past metadata
+ keys_head_dim_last: bool = True
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor: ProcessorMixin, # Hack
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "IdeficsCausalLMBatch":
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ prefix_offsets = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
+ requests_idx_mapping[r.id] = i
+ inputs.append(r.inputs)
+ next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+            # Longest requested truncation in the batch (prompts are truncated to at most this many tokens)
+            max_truncation = max(max_truncation, r.truncate)
+            # Sum of max_new_tokens across all requests, i.e. an upper bound on tokens still to be generated
+            max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ prompts = []
+ for inp in inputs:
+ # Each input is encoded into a list, where each element of this input list is either a string or a URL
+ from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
+ if isinstance(inp, str):
+ prompts.append([inp])
+ elif isinstance(inp, list):
+ if not all(isinstance(item, str) for item in inp):
+ raise ValueError("All elements in the list must be strings (text string or image URL)")
+                # `inp` has already been validated as a list of strings, so use it directly as the prompt
+                prompts.append(inp)
+ else:
+ raise ValueError("Unsupported type of input")
+ # I initially wanted to send the images in string base64 but they are too big to send in a consistent way...
+ # So resorting to uploading the image to a server and pulling them back
+ # splitted_inp = split_str_on_im_markdown(inp)
+ # prompts.append(
+ # [
+        #         im_markdown_to_pil(s) if s.startswith('<img src="data:image/png;base64,') else s
+        #         for s in splitted_inp
+        #     ]
+        # )
+
+    @tracer.start_as_current_span("filter")
+    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+ from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
+ # It deletes requests from the batch. For instance when client lost connection
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+
+ keep_indices = []
+
+ # New values after filtering
+ requests_idx_mapping = {}
+ requests = []
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ max_input_length = 0
+
+ next_token_choosers = []
+ stopping_criterias = []
+
+ total_remaining_decode_tokens = 0
+ new_padding_right_offset = 0
+
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ keep_indices.append(idx)
+
+ requests.append(self.requests[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+ all_input_ids.append(self.all_input_ids[idx])
+
+ request_input_length = self.input_lengths[idx]
+ input_lengths.append(request_input_length)
+ max_input_length = max(max_input_length, request_input_length)
+
+ next_token_choosers.append(self.next_token_choosers[idx])
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+ remaining_decode_tokens = (
+ stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+ )
+ total_remaining_decode_tokens += remaining_decode_tokens
+ new_padding_right_offset = max(
+ new_padding_right_offset, remaining_decode_tokens
+ )
+
+ # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+ input_ids = self.input_ids[keep_indices]
+ position_ids = self.position_ids[keep_indices]
+ self.attention_mask = self.attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_input_length) : (
+ self.attention_mask.shape[1] - self.padding_right_offset
+ )
+ + new_padding_right_offset,
+ ]
+ # Do the same for pixel_values and image_attention_mask
+ pixel_values = self.pixel_values[keep_indices]
+ self.image_attention_mask = self.image_attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_input_length) : (
+ self.image_attention_mask.shape[1] - self.padding_right_offset
+ )
+ + new_padding_right_offset,
+ :
+ ]
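+        # Slicing sketch with illustrative numbers: if the masks are 30 columns wide with padding_right_offset=20,
+        # and the kept requests have max_input_length=8 with new_padding_right_offset=5, the slices above keep
+        # columns [-28 : 15], i.e. the last 8 "real" positions plus 5 columns of right padding.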
+
+ # Ensure that past_key_values tensors can be updated in-place
+ if type(self.past_key_values[0]) == tuple:
+ self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+ # Update tensors in-place to allow incremental garbage collection
+ past_kv_length = max_input_length - 1
+ for layer in self.past_key_values:
+ past_keys, past_values = layer
+ if len(past_keys.shape) == 3:
+ # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+ past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+ past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+ if self.keys_head_dim_last:
+ layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+ else:
+ layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+ del past_keys
+ layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+ del past_values
+
+ max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids = input_ids
+ self.pixel_values = pixel_values
+ self.position_ids = position_ids
+ self.all_input_ids = all_input_ids
+ self.input_lengths = input_lengths
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
+ self.next_token_choosers = next_token_choosers
+ self.stopping_criterias = stopping_criterias
+ self.max_input_length = max_input_length
+ self.padding_right_offset = new_padding_right_offset
+ self.max_tokens = max_tokens
+
+ return self
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
+ from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
+ # It adds new requests to the batch
+ # Used for padding
+ total_batch_size = 0
+ max_input_length = 0
+ max_num_images = 0
+ padding_right_offset = 0
+ for batch in batches:
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.max_input_length)
+ max_num_images = max(max_num_images, batch.pixel_values.size(1))
+ padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ next_token_choosers = []
+ stopping_criterias = []
+ max_tokens = 0
+
+ # Batch tensors
+ input_ids = None
+ attention_mask = None
+ position_ids = None
+ pixel_values = None
+ image_attention_mask = None
+ past_key_values = []
+
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+ for i, batch in enumerate(batches):
+ requests.extend(batch.requests)
+ input_lengths.extend(batch.input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+ all_input_ids.extend(batch.all_input_ids)
+ next_token_choosers.extend(batch.next_token_choosers)
+ stopping_criterias.extend(batch.stopping_criterias)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # We only concatenate batches that did at least one step
+ if batch.past_key_values is None:
+ raise ValueError("only concatenate prefilled batches")
+
+ # Create empty tensor
+ # input_ids is always of shape [batch_size, 1]
+ # We do not need to pad it
+ if input_ids is None:
+ input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+ # Copy to correct indices
+ input_ids[start_index:end_index] = batch.input_ids
+
+ # Create padded tensor
+ if attention_mask is None:
+ attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_input_length + padding_right_offset),
+ )
+
+ curr_batch_max_num_images = batch.pixel_values.size(1)
+ if pixel_values is None:
+ pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224))
+ pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values
+
+ if image_attention_mask is None:
+ image_attention_mask = batch.image_attention_mask.new_zeros(
+ (total_batch_size, max_input_length + padding_right_offset, max_num_images)
+ )
+
+ # We need to slice the attention mask to remove padding from previous steps
+ # and to remove unused allocated space
+ left_offset = max_input_length - batch.max_input_length
+ batch_left_offset = (
+ batch.attention_mask.shape[1]
+ - batch.max_input_length
+ - batch.padding_right_offset
+ )
+ attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = batch.attention_mask[
+ :,
+ batch_left_offset : -batch.padding_right_offset,
+ ]
+ from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
+ from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
+ image_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ :curr_batch_max_num_images
+ ] = batch.image_attention_mask[
+ :,
+ batch_left_offset : - batch.padding_right_offset,
+ :
+ ]
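+            # Offset sketch with illustrative numbers: if the merged batch has max_input_length=8 and this batch has
+            # max_input_length=5, padding_right_offset=3 and masks 10 columns wide, then left_offset=3 and
+            # batch_left_offset=2, so this batch's 5 real mask columns land in columns [3 : -padding_right_offset)
+            # of the merged masks, left-padded up to the merged max_input_length.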
+
+ # Create empty tensor
+ # position_ids is always of shape [batch_size, 1]
+ if position_ids is None:
+ position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+ position_ids[start_index:end_index] = batch.position_ids
+
+ # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+ # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
+ # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+ # And ensure that we can update tensors in-place
+ if type(batch.past_key_values[0]) == tuple:
+ batch.past_key_values = [
+ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+ for layer in batch.past_key_values
+ ]
+ elif len(batch.past_key_values[0][0].shape) == 3:
+ for layer in batch.past_key_values:
+ for k, t in enumerate(layer):
+ layer[k] = t.view(len(batch), -1, *t.shape[-2:])
+
+ # Add eventual padding tokens that were added while concatenating
+ max_tokens += batch.max_tokens + (
+ max_input_length - batch.max_input_length
+ ) * len(batch)
+
+ start_index = end_index
+
+ first_past_kvs = batches[0].past_key_values
+ _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
+
+ padded_past_values_shape = (
+ total_batch_size,
+ num_heads,
+ max_input_length - 1,
+ head_dim,
+ )
+
+ if batches[0].keys_head_dim_last:
+ padded_past_keys_shape = padded_past_values_shape
+ else:
+ # seq_length is last for BLOOM
+ padded_past_keys_shape = (
+ total_batch_size,
+ num_heads,
+ head_dim,
+ max_input_length - 1,
+ )
+
+ # Iterate over attention layers
+ # Concatenate past key values layer by layer to allow incremental garbage collection
+ for j in range(len(first_past_kvs)):
+ padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+ start_index = 0
+ for batch in batches:
+ past_keys = batch.past_key_values[j][0]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][0] = None
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the keys to remove the padding from previous batches
+ past_seq_len = batch.max_input_length - 1
+ if batch.keys_head_dim_last:
+ padded_past_keys[
+ start_index:end_index, :, -past_seq_len:, :
+ ] = past_keys[:, :, -past_seq_len:, :]
+ else:
+ # BLOOM case
+ padded_past_keys[
+ start_index:end_index, :, :, -past_seq_len:
+ ] = past_keys[:, :, :, -past_seq_len:]
+ del past_keys
+
+ start_index = end_index
+
+ padded_past_values = first_past_kvs[j][1].new_zeros(
+ padded_past_values_shape
+ )
+ start_index = 0
+ for batch in batches:
+ past_values = batch.past_key_values[j][1]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][1] = None
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past values to remove the padding from previous batches
+ past_seq_len = batch.max_input_length - 1
+ padded_past_values[
+ start_index:end_index, :, -past_seq_len:, :
+ ] = past_values[:, :, -past_seq_len:, :]
+ del past_values
+
+ # Update values
+ start_index = end_index
+
+ past_key_values.append([padded_past_keys, padded_past_values])
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_attention_mask=image_attention_mask,
+ past_key_values=past_key_values,
+ all_input_ids=all_input_ids,
+ input_lengths=input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ max_input_length=max_input_length,
+ padding_right_offset=padding_right_offset,
+ keys_head_dim_last=batches[0].keys_head_dim_last,
+ max_tokens=max_tokens,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+class IdeficsCausalLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ dtype = torch.float16 if dtype is None else dtype
+ else:
+ if quantize:
+ raise ValueError("quantization is not available on CPU")
+
+ device = torch.device("cpu")
+ dtype = torch.float32
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ self.processor = AutoProcessor.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ model = IdeficsForVisionText2Text.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ device_map="auto"
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1
+ else None,
+ load_in_8bit=quantize == "bitsandbytes",
+ trust_remote_code=trust_remote_code,
+ )
+ if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+ model = model.cuda()
+
+ if tokenizer.pad_token_id is None:
+ if model.config.pad_token_id is not None:
+ tokenizer.pad_token_id = model.config.pad_token_id
+ elif model.config.eos_token_id is not None:
+ tokenizer.pad_token_id = model.config.eos_token_id
+ elif tokenizer.eos_token_id is not None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ else:
+                tokenizer.add_special_tokens({"pad_token": "<unk>"})
+
+ super(IdeficsCausalLM, self).__init__(
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ @property
+ def batch_type(self) -> Type[IdeficsCausalLMBatch]:
+ return IdeficsCausalLMBatch
+
+ def decode(self, generated_ids: List[int]) -> str:
+ return self.tokenizer.decode(
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ pixel_values,
+ image_attention_mask,
+        past_key_values: Optional[List[Tuple]] = None,
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+ # Model Forward
+ kwargs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "image_attention_mask": image_attention_mask,
+ "past_key_values": past_key_values,
+ "use_cache": True,
+ "return_dict": True,
+ }
+ if self.has_position_ids:
+ kwargs["position_ids"] = position_ids
+
+ outputs = self.model.forward(**kwargs)
+ return outputs.logits, outputs.past_key_values
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batch: IdeficsCausalLMBatch
+ ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
+ from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
+ # slice the attention mask to the correct shape
+ attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+ if batch.input_ids.size(1) == 1:
+            # THIS is a hack: the first time idefics.generate is called, we need the whole image_attention_mask
+            # (size bs x max_seq_len x max_num_images), but on subsequent calls we only need the last slice along
+            # the `max_seq_len` dimension. This is due to the nature of IDEFICS: it behaves like an encoder-decoder,
+            # so when decoding, only the currently generated token needs to attend to the encoder hidden states
+            # (i.e. the vision encoder).
+            # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic.
+            image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1)  # TODO: verify this index; there may be an off-by-one (+1) lurking here
+ else:
+ image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
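+        # Shape sketch: during prefill, input_ids is [bsz, seq_len] and the full image_attention_mask
+        # [bsz, seq_len, num_images] is passed; during decoding, input_ids is [bsz, 1] and only the slice for the
+        # last real position, [bsz, 1, num_images], is kept.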
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
+
+ logits, past = self.forward(
+ input_ids=batch.input_ids,
+ attention_mask=attention_mask,
+ position_ids=batch.position_ids,
+ pixel_values=batch.pixel_values,
+ image_attention_mask=image_attention_mask,
+ past_key_values=batch.past_key_values,
+ )
+
+ # Results
+ generations: List[Generation] = []
+ stopped = True
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ logits,
+ batch.next_token_choosers,
+ batch.stopping_criterias,
+ batch.all_input_ids,
+ )
+
+ # For each member of the batch
+ for i, (
+ request,
+ input_length,
+ prefix_offset,
+ read_offset,
+ logits,
+ next_token_chooser,
+ stopping_criteria,
+ all_input_ids,
+ ) in enumerate(iterator):
+ # Select next token
+ next_token_id, logprobs = next_token_chooser(
+ all_input_ids.view(1, -1), logits[-1:, :]
+ )
+
+ # Append next token to all tokens
+ all_input_ids = torch.cat([all_input_ids, next_token_id])
+ new_input_length = input_length + 1
+
+ # Generated token
+ next_token_logprob = logprobs[-1, next_token_id]
+ next_token_id_squeezed = next_token_id.squeeze()
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[:, 0], prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(
+ next_token_id_squeezed,
+ next_token_text,
+ )
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ output_text = self.decode(
+ all_input_ids[-stopping_criteria.current_tokens :, 0]
+ )
+ # Get seed
+ if isinstance(next_token_chooser.choice, Sampling):
+ seed = next_token_chooser.choice.seed
+ else:
+ seed = None
+
+ generated_text = GeneratedText(
+ output_text, stopping_criteria.current_tokens, reason, seed
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ # Remove generated token to only have prefill and add nan for first prompt token
+ prefill_logprobs = [float("nan")] + torch.log_softmax(
+ logits, -1
+ ).gather(1, all_input_ids[1:]).squeeze(1)[
+ -new_input_length:-1
+ ].tolist()
+ prefill_token_ids = all_input_ids[-new_input_length:-1]
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ prefill_tokens = PrefillTokens(
+ prefill_token_ids, prefill_logprobs, prefill_texts
+ )
+ else:
+ prefill_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ next_token_id_squeezed,
+ next_token_logprob,
+ next_token_text,
+ next_token_id_squeezed.item() in self.all_special_ids,
+ generated_text,
+ )
+
+ generations.append(generation)
+
+ # Update values
+ batch.input_ids[i, 0] = next_token_id
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
+ batch.all_input_ids[i] = all_input_ids
+ batch.input_lengths[i] = new_input_length
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+ # We finished all generations in the batch; there is no next batch
+ if stopped:
+ return generations, None
+
+ # Slice unused values from prefill
+ batch.input_ids = batch.input_ids[:, :1]
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
+
+ # Update attention_mask as we added a new token to input_ids
+ batch.attention_mask[:, -batch.padding_right_offset] = 1
+ batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :]
+ # Decrease right offset
+ batch.padding_right_offset -= 1
+
+ # Update position_ids
+ batch.position_ids = batch.position_ids[:, -1:] + 1
+
+ # Update past key values
+ batch.past_key_values = past
+
+ from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
+ return generations, batch
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 1cedc151..e7ed8875 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -14,6 +14,7 @@ from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -54,9 +55,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
- batch = self.model.batch_type.from_pb(
- request.batch, self.model.tokenizer, self.model.dtype, self.model.device
- )
+        if self.model.batch_type == IdeficsCausalLMBatch:  # Hack: I would rather pass kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+ )
max_supported_total_tokens = self.model.warmup(batch)
return generate_pb2.WarmupResponse(
@@ -64,9 +70,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
async def Prefill(self, request, context):
- batch = self.model.batch_type.from_pb(
- request.batch, self.model.tokenizer, self.model.dtype, self.model.device
- )
+        if self.model.batch_type == IdeficsCausalLMBatch:  # Hack: I would rather pass kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+ )
generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 9688f6e0..745c1d2e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -51,7 +51,31 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
ln.bias = None
return ln
+@classmethod
+def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ bias = weights.get_tensor(f"{prefix}.bias")
+ with init_empty_weights():
+ conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+ conv2d.weight = nn.Parameter(weight)
+ conv2d.bias = nn.Parameter(bias)
+ return conv2d
+
+
+@classmethod
+def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ with init_empty_weights():
+ conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+
+ conv2d.weight = nn.Parameter(weight)
+ conv2d.bias = None
+ return conv2d
+
+
+torch.nn.Conv2d.load = load_conv2d
+torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
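+# Usage sketch (see IdeficsVisionEmbeddings in idefics_vision.py): the vision patch projection is loaded with
+#   nn.Conv2d.load_no_bias(prefix=f"{prefix}.patch_embedding", weights=weights, in_channels=config.num_channels,
+#                          out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)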
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias