Mirror of https://github.com/huggingface/text-generation-inference.git
Idefics2.

parent 9f44af470c
commit 9f3ce55ce2

server/text_generation_server/models/custom_modeling/idefics2_modeling.py (new file, 457 lines)
@@ -0,0 +1,457 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from typing import Optional, List, Dict, Tuple, Union

# The helpers below are implied by the symbols used in this hunk; the exact import locations
# are assumed. `Idefics2VisionAttention`, `Idefics2MLP` and the Idefics2 config classes
# referenced further down are expected to be defined elsewhere in this file.
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging

logger = logging.get_logger(__name__)

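# `_get_unpad_data` is used by `_upad_input` below but is not part of this hunk. The
# following is a minimal sketch, assuming the helper matches the version copied across the
# transformers flash-attention modeling files:
def _get_unpad_data(attention_mask):
    # Number of non-padding tokens per sequence, shape (batch_size,)
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of every non-padding position in the (batch_size * seq_len) layout
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths prefixed with 0, as expected by flash_attn_varlen_func
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

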
class Idefics2VisionFlashAttention2(nn.Module):
    """
    Idefics2 flash attention module. The weights of the module are the same as in `Idefics2VisionAttention` and stay
    untouched. The only required change is in the forward pass, which needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.is_causal = False  # Hack to make sure we don't use a causal mask

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # TODO: These transposes are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (Idefics2RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seem to be silently cast to float32; this might be related to the fact"
                " that you have upcast embedding or layer norm layers to float32. We will cast back the input to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores, and finally pad the attention scores back.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details,
            # please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


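# Illustrative usage of the module above (hypothetical names, shapes only): given patch
# embeddings of shape (batch, num_patches, hidden_size) and an optional padding mask of
# shape (batch, num_patches) with 1 for real patches and 0 for padding,
#
#     attn = Idefics2VisionFlashAttention2(vision_config)
#     attn_output, _ = attn(hidden_states=patch_embeds, attention_mask=patch_attention_mask)
#
# the output keeps the (batch, num_patches, hidden_size) shape; padded positions are
# unpadded before the varlen flash-attention kernel and padded back afterwards.
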
IDEFICS_VISION_ATTENTION_CLASSES = {
    "eager": Idefics2VisionAttention,
    "flash_attention_2": Idefics2VisionFlashAttention2,
}

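# The mapping above mirrors the transformers convention of picking the attention
# implementation from the config; a hypothetical dispatch would look like
#
#     attn_cls = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation]
#     self.self_attn = attn_cls(config)
#
# Idefics2EncoderLayer below hard-codes the flash-attention variant instead.

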
# Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
class Idefics2VisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead with Siglip->Idefics2
class Idefics2MultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: Idefics2VisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Idefics2MLP(config)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


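# Shape walkthrough for the pooling head above (illustrative): with hidden_state of shape
# (batch, seq_len, hidden_size), the learned probe (1, 1, hidden_size) is repeated to
# (batch, 1, hidden_size) and used as the attention query over hidden_state, so
# hidden_state[:, 0] is a pooled vector of shape (batch, hidden_size). For example:
#
#     pooler = Idefics2MultiheadAttentionPoolingHead(vision_config)   # hypothetical config
#     pooled = pooler(torch.randn(2, 64, vision_config.hidden_size))  # -> (2, hidden_size)

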
class Idefics2EncoderLayer(nn.Module):
    def __init__(self, config: Idefics2Config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Idefics2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very
                large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
class Idefics2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Idefics2EncoderLayer`].

    Args:
        config: Idefics2Config
    """

    def __init__(self, config: Idefics2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )

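# Illustrative usage of the encoder above (hypothetical inputs): it consumes pre-computed
# patch embeddings rather than token ids and, with return_dict=True, wraps the result in a
# BaseModelOutput.
#
#     encoder = Idefics2Encoder(vision_config)
#     out = encoder(inputs_embeds=patch_embeds, attention_mask=None, return_dict=True)
#     last_hidden = out.last_hidden_state  # (batch, num_patches, hidden_size)
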
server/text_generation_server/models/idefics2.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoProcessor,
)

from text_generation_server.models.custom_modeling.idefics2_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import (
    IdeficsProcessor,
)
from transformers import LlamaTokenizerFast
from text_generation_server.models.custom_modeling.idefics2_modeling import (
    Idefics2ForVisionText2Text,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class IDEFICS2Sharded(IdeficsCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            # 9b seems to work correctly enough in float16, but 80b seems
            # to be really saturating for f16.
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        self.device, self.dtype = device, dtype

        config = IdeficsConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.use_medusa = use_medusa
        config.vision_config.quantize = quantize

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        self.processor = IdeficsProcessor.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
        )

        model = Idefics2ForVisionText2Text(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(IdeficsCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

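# Illustrative construction (hypothetical arguments; the server normally instantiates this
# class through its model-loading dispatch rather than directly):
#
#     model = IDEFICS2Sharded(
#         "HuggingFaceM4/idefics2-8b",
#         quantize=None,
#         dtype=torch.float16,
#     )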