mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00
* feat: add ruff and resolve issue * fix: update client exports and adjust after rebase * fix: adjust syntax to avoid circular import * fix: adjust client ruff settings * fix: lint and refactor import check and avoid model enum as global names * fix: improve fbgemm_gpu check and lints * fix: update lints * fix: prefer comparing model enum over str * fix: adjust lints and ignore specific rules * fix: avoid unneeded quantize check
530 lines
21 KiB
Python
530 lines
21 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
|
|
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
|
from transformers.utils import (
|
|
ModelOutput,
|
|
logging,
|
|
)
|
|
from text_generation_server.layers import (
|
|
TensorParallelColumnLinear,
|
|
TensorParallelRowLinear,
|
|
TensorParallelEmbedding,
|
|
)
|
|
|
|
# Module-level logger from transformers' logging utilities; not referenced by any definition in this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@dataclass
class IdeficsVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Every field defaults to `None`, so all of them are annotated `Optional`.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    # Annotated Optional because the default is None (the previous bare annotation contradicted it).
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
|
|
class IdeficsVisionEmbeddings(nn.Module):
    """Turn images into a token sequence: a learned class token followed by one token per patch.

    Args:
        prefix: weight-name prefix; every tensor of this module is loaded from `{prefix}.<name>`.
        config: vision config providing `hidden_size`, `image_size`, `patch_size`, `num_channels`.
        weights: weight loader providing `get_tensor`, `device` and the tensor-parallel process group.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding")
        )

        # Non-overlapping patches: kernel size == stride == patch_size.
        self.patch_embedding = nn.Conv2d.load_no_bias(
            prefix=f"{prefix}.patch_embedding",
            weights=weights,
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # +1 position for the class token.
        self.num_positions = self.num_patches + 1
        # Resolved relative to `prefix`, like the other weights of this module,
        # instead of a hard-coded absolute weight name.
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        self.position_ids = (
            torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed `pixel_values` and return `(batch, 1 + num_patches, embed_dim)` embeddings.

        The fixed-length position embedding is broadcast-added, so the image is expected
        to produce exactly `num_patches` patches (i.e. spatial size `config.image_size`).
        """
        batch_size = pixel_values.shape[0]
        # Match the conv weight dtype (e.g. fp16) regardless of the input dtype.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
|
|
|
|
|
|
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
|
|
class IdeficsVisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, split across tensor-parallel shards.

    q/k/v are loaded as column-parallel linears and the output projection as a row-parallel
    linear; after construction `num_heads` and `embed_dim` hold the per-shard values
    (global values divided by the process-group size).
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        num_shards = weights.process_group.size()
        if self.num_heads % num_shards != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {num_shards})"
            )
        # From here on, sizes are per shard.
        self.num_heads = self.num_heads // num_shards
        self.embed_dim = self.embed_dim // num_shards

        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
        )
        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
        )
        self.out_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: `(bsz, seq, heads*head_dim)` -> `(bsz, heads, seq, head_dim)`."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def _apply_mask(
        self,
        attn_weights: torch.Tensor,
        mask: Optional[torch.Tensor],
        bsz: int,
        tgt_len: int,
        src_len: int,
    ) -> torch.Tensor:
        """Add an optional additive `(bsz, 1, tgt_len, src_len)` mask to `(bsz * heads, tgt_len, src_len)` scores.

        Raises:
            ValueError: if `mask` has any other shape.
        """
        if mask is None:
            return attn_weights
        if mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {mask.size()}"
            )
        # Broadcast the mask over the head dimension, then fold heads back into the batch.
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `(batch, seq_len, channels)` input.
            attention_mask: optional additive mask of shape `(batch, 1, seq_len, seq_len)`.
            causal_attention_mask: optional additive mask of the same shape, applied before `attention_mask`.
            output_attentions: when truthy, also return the post-softmax weights.

        Returns:
            `(attn_output, attn_weights)` where `attn_weights` is `(batch, heads, seq_len, seq_len)`
            if `output_attentions` else `None`.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Scale queries up front; equivalent to scaling the dot-product scores.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold heads into the batch dim so a single bmm handles every head.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # The causal mask is applied first, then the padding mask.
        attn_weights = self._apply_mask(
            attn_weights, causal_attention_mask, bsz, tgt_len, src_len
        )
        attn_weights = self._apply_mask(
            attn_weights, attention_mask, bsz, tgt_len, src_len
        )

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Per-head view of the probabilities, only built when the caller asks for it.
        attn_weights_reshaped = (
            attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if output_attentions
            else None
        )

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        # self.embed_dim is the per-shard width produced by the column-parallel projections.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
|
|
|
|
|
|
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
|
|
class IdeficsVisionMLP(nn.Module):
    """Feed-forward sub-layer: column-parallel `fc1`, activation `config.hidden_act`, row-parallel `fc2`.

    Args:
        prefix: weight-name prefix; the linears are loaded from `{prefix}.fc1` / `{prefix}.fc2`.
        config: config providing `hidden_act` (looked up in `ACT2FN`).
        weights: weight loader passed to the tensor-parallel linear loaders.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        fc1_name, fc2_name = f"{prefix}.fc1", f"{prefix}.fc2"
        self.fc1 = TensorParallelColumnLinear.load(
            config, prefix=fc1_name, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            config, prefix=fc2_name, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `fc2(activation(fc1(hidden_states)))`."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
|
|
|
|
|
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
|
|
class IdeficsVisionEncoderLayer(nn.Module):
    """One pre-LayerNorm transformer block: `x + attn(ln1(x))`, then `h + mlp(ln2(h))`.

    Args:
        prefix: weight-name prefix for `self_attn`, `layer_norm1`, `mlp` and `layer_norm2`.
        config: config providing `hidden_size` and `layer_norm_eps` (plus what the sub-modules read).
        weights: weight loader forwarded to every sub-module.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = IdeficsVisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = IdeficsVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`): additive mask of size `(batch, 1, tgt_len, src_len)`;
                padding positions carry very large negative values. May be `None`.
            causal_attention_mask (`torch.FloatTensor`): additive causal mask of the same size, or `None`.
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)` when `output_attentions` is set.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        # Residual around attention.
        hidden_states = hidden_states + attn_out
        # Residual around the MLP.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
|
|
class IdeficsVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`IdeficsVisionEncoderLayer`].

    Gradient checkpointing is not supported; every layer is called directly.

    Args:
        prefix: weight-name prefix; layer `i` is loaded from `{prefix}.encoder.layers.{i}`.
        config: IdeficsVisionConfig
        weights: weight loader forwarded to every layer.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                IdeficsVisionEncoderLayer(
                    prefix=f"{prefix}.encoder.layers.{layer_id}",
                    config=config,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive mask added to the attention scores before the softmax; use large negative values
                for positions that must not be attended to.
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive causal mask, applied in each layer before `attention_mask`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (including the input embeddings).
                Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.

        Returns:
            `BaseModelOutput`, or a tuple of the non-`None` entries of
            `(last_hidden_state, hidden_states, attentions)` when `return_dict` is falsy.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the final output is appended after the loop.
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
|
|
|
|
|
|
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
|
|
class IdeficsVisionTransformer(nn.Module):
    """Vision tower: embeddings -> `pre_layrnorm` -> encoder -> post-LayerNormed class-token pooling.

    Args:
        prefix: weight-name prefix for the whole tower.
        config: vision config (also supplies the default output flags used in `forward`).
        weights: weight loader forwarded to every sub-module.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config

        self.embeddings = IdeficsVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        # "pre_layrnorm" (sic) is the key looked up in `weights`; keep the spelling.
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
        )
        self.encoder = IdeficsVisionEncoder(
            prefix=prefix, config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""Encode `pixel_values` into per-token states plus a pooled representation.

        Flags left as `None` fall back to `config.output_attentions`,
        `config.output_hidden_states` and `config.use_return_dict`.

        Returns:
            `BaseModelOutputWithPooling` whose `pooler_output` is the post-LayerNormed first token of
            `last_hidden_state`; when `return_dict` is falsy, the tuple
            `(last_hidden_state, pooled_output, *encoder_extras)`.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Position 0 holds the class token prepended by the embeddings module.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|