mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Temporary dump of idefics2.
This commit is contained in:
parent
9f3ce55ce2
commit
f68ccfd023
@ -114,7 +114,7 @@ impl Client {
|
|||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(";
|
inputs.push_str("");
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
|
@ -92,6 +92,7 @@ pub enum Config {
|
|||||||
ClipVisionModel(ClipVisionModel),
|
ClipVisionModel(ClipVisionModel),
|
||||||
Mistral,
|
Mistral,
|
||||||
Idefics,
|
Idefics,
|
||||||
|
Idefics2,
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
Santacoder,
|
Santacoder,
|
||||||
|
@ -540,7 +540,30 @@ fn prepare_input(
|
|||||||
inputs = modified_inputs;
|
inputs = modified_inputs;
|
||||||
tokenizer_query
|
tokenizer_query
|
||||||
}
|
}
|
||||||
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
|
Some(Config::Idefics | Config::Idefics2) => {
|
||||||
|
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||||
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
|
let mut start = 0;
|
||||||
|
for chunk in RE.find_iter(&inputs) {
|
||||||
|
let chunk_start = chunk.start();
|
||||||
|
let chunk_end = chunk.end();
|
||||||
|
if chunk_start != start {
|
||||||
|
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||||
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
|
}
|
||||||
|
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
|
let slots = 1;
|
||||||
|
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||||
|
modified_inputs.push_str(&image_uri);
|
||||||
|
start = chunk_end;
|
||||||
|
}
|
||||||
|
if start != inputs.len() - 1 {
|
||||||
|
modified_inputs.push_str(&inputs[start..]);
|
||||||
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
|
}
|
||||||
|
inputs = modified_inputs;
|
||||||
|
tokenizer_query
|
||||||
|
}
|
||||||
_ => inputs.clone(),
|
_ => inputs.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ try:
|
|||||||
)
|
)
|
||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
from text_generation_server.models.idefics import IDEFICSSharded
|
||||||
from text_generation_server.models.llava_next import LlavaNext
|
from text_generation_server.models.llava_next import LlavaNext
|
||||||
|
from text_generation_server.models.idefics2 import Idefics2
|
||||||
from text_generation_server.models.flash_mistral import FlashMistral
|
from text_generation_server.models.flash_mistral import FlashMistral
|
||||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||||
from text_generation_server.models.flash_phi import FlashPhi
|
from text_generation_server.models.flash_phi import FlashPhi
|
||||||
@ -579,6 +580,18 @@ def get_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
|
if model_type == "idefics2":
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
return Idefics2(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
use_medusa=use_medusa,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
|
|
||||||
if model_type == "llava_next":
|
if model_type == "llava_next":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashMistralForCausalLM(torch.nn.Module):
|
class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, name=None):
|
||||||
|
if name is None:
|
||||||
|
name = "model"
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix=(
|
prefix=(
|
||||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
f"{name}.embed_tokens"
|
||||||
|
if not prefix
|
||||||
|
else f"{prefix}.{name}.embed_tokens"
|
||||||
),
|
),
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.model = MistralModel(
|
self.model = MistralModel(
|
||||||
prefix="model" if not prefix else f"{prefix}.model",
|
prefix=name if not prefix else f"{prefix}.{name}",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
# TODO dirty hack for idefics2.
|
||||||
|
prefix=(
|
||||||
|
"lm_head" if not prefix or name is not "model" else f"{prefix}.lm_head"
|
||||||
|
),
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.max_past = config.sliding_window
|
self.max_past = config.sliding_window
|
||||||
|
803
server/text_generation_server/models/custom_modeling/idefics2.py
Normal file
803
server/text_generation_server/models/custom_modeling/idefics2.py
Normal file
@ -0,0 +1,803 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Idefics2 model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionEmbeddings(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
|
||||||
|
resolution.
|
||||||
|
|
||||||
|
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
|
||||||
|
which allows treating images in their native aspect ratio and without the need to resize them to the same
|
||||||
|
fixed size. In particular, we start from the original pre-trained SigLIP model
|
||||||
|
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.embed_dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
)
|
||||||
|
self.patch_embedding.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.patch_embedding.bias = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_patches_per_side = self.image_size // self.patch_size
|
||||||
|
self.num_patches = self.num_patches_per_side**2
|
||||||
|
self.num_positions = self.num_patches
|
||||||
|
self.position_embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.position_embedding", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||||
|
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
max_nb_patches_h, max_nb_patches_w = (
|
||||||
|
max_im_h // self.patch_size,
|
||||||
|
max_im_w // self.patch_size,
|
||||||
|
)
|
||||||
|
boundaries = torch.arange(
|
||||||
|
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
|
||||||
|
)
|
||||||
|
position_ids = torch.full(
|
||||||
|
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||||
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
|
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||||
|
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||||
|
|
||||||
|
bucket_coords_h = torch.bucketize(
|
||||||
|
fractional_coords_h, boundaries, right=True
|
||||||
|
)
|
||||||
|
bucket_coords_w = torch.bucketize(
|
||||||
|
fractional_coords_w, boundaries, right=True
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_ids = (
|
||||||
|
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
||||||
|
).flatten()
|
||||||
|
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||||
|
|
||||||
|
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||||
|
embeddings = embeddings + self.position_embedding(position_ids)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = self.embed_dim // self.num_heads
|
||||||
|
if self.head_size * self.num_heads != self.embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.scale = self.head_size**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||||
|
|
||||||
|
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.out_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
qkv = self.qkv(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
k_v_seq_len = key_states.shape[-2]
|
||||||
|
attn_weights = (
|
||||||
|
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(
|
||||||
|
attn_weights, p=self.dropout, training=self.training
|
||||||
|
)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.fc2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2EncoderLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.self_attn = Idefics2VisionAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm1 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm2 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2VisionMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Encoder(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2EncoderLayer(
|
||||||
|
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionTransformer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embeddings = Idefics2VisionEmbeddings(
|
||||||
|
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.encoder = Idefics2Encoder(
|
||||||
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values,
|
||||||
|
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
):
|
||||||
|
batch_size = pixel_values.size(0)
|
||||||
|
if patch_attention_mask is None:
|
||||||
|
patch_size = self.config.patch_size
|
||||||
|
patch_attention_mask = torch.ones(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
pixel_values.size(2) // patch_size,
|
||||||
|
pixel_values.size(3) // patch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
patch_attention_mask = patch_attention_mask.to(
|
||||||
|
dtype=torch.bool, device=pixel_values.device
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.embeddings(
|
||||||
|
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||||
|
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||||
|
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||||
|
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||||
|
if not torch.any(~patch_attention_mask):
|
||||||
|
patch_attention_mask = None
|
||||||
|
else:
|
||||||
|
patch_attention_mask = _prepare_4d_attention_mask(
|
||||||
|
patch_attention_mask, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = encoder_outputs
|
||||||
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2MLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.text_config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.text_config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2RMSNorm(nn.Module):
|
||||||
|
def __init__(self, prefix, weights, eps):
|
||||||
|
"""
|
||||||
|
Idefics2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_idx = None
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.num_heads = config.perceiver_config.resampler_n_heads
|
||||||
|
self.head_size = config.perceiver_config.resampler_head_dim
|
||||||
|
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.attention_dropout = config.perceiver_config.attention_dropout
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.text_config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.out_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = latents.size()
|
||||||
|
kv_seq_len = q_len + context.size()[1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
hidden_states = torch.concat([context, latents], dim=-2)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
||||||
|
|
||||||
|
qkv = self.qkv(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_size)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
self.input_latents_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_latents_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.input_context_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_context_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.self_attn = Idefics2PerceiverAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
"""
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.input_latents_norm(latents)
|
||||||
|
context = self.input_context_norm(context)
|
||||||
|
|
||||||
|
latents = self.self_attn(
|
||||||
|
latents=latents,
|
||||||
|
context=context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
latents = residual + latents
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.post_attention_layernorm(latents)
|
||||||
|
latents = self.mlp(latents)
|
||||||
|
latents = residual + latents
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverResampler(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.hidden_act = config.perceiver_config.hidden_act
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
# Create Latents for Perceiver
|
||||||
|
self.latents = weights.get_tensor(f"{prefix}.latents")
|
||||||
|
|
||||||
|
# Create Transformer Blocks
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2PerceiverLayer(
|
||||||
|
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for idx in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.text_config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# seq embed -> bsz seq embed
|
||||||
|
latents = self.latents.unsqueeze(0).expand(
|
||||||
|
(context.shape[0], *self.latents.size())
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_attention_mask = torch.ones(
|
||||||
|
(attention_mask.size(0), latents.size(1)),
|
||||||
|
dtype=attention_mask.dtype,
|
||||||
|
device=attention_mask.device,
|
||||||
|
)
|
||||||
|
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
|
||||||
|
attention_mask = _prepare_4d_attention_mask(
|
||||||
|
attention_mask, latents.dtype, tgt_len=self.n_latents
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_context = latents
|
||||||
|
for perceiver_layer in self.layers:
|
||||||
|
compressed_context = perceiver_layer(
|
||||||
|
compressed_context,
|
||||||
|
context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
compressed_context = self.norm(compressed_context)
|
||||||
|
|
||||||
|
return compressed_context
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Connector(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.modality_projection = Idefics2MLP(
|
||||||
|
prefix=f"{prefix}.modality_projection", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.perceiver_resampler = Idefics2PerceiverResampler(
|
||||||
|
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_hidden_states, attention_mask):
|
||||||
|
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||||
|
image_hidden_states = self.perceiver_resampler(
|
||||||
|
context=image_hidden_states, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
return image_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2ForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
config.vision_config.use_medusa = config.use_medusa
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.use_medusa = config.use_medusa
|
||||||
|
|
||||||
|
vision_config = config.vision_config
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="model" if not prefix else f"{prefix}.model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
name="text_model",
|
||||||
|
)
|
||||||
|
self.dtype = weights.dtype
|
||||||
|
self.vision_model = Idefics2VisionTransformer(
|
||||||
|
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||||
|
config=vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.connector = Idefics2Connector(
|
||||||
|
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||||
|
self.image_token_id = config.image_token_id
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_input_ids_with_image_features(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
image_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
|
mask = input_ids == self.config.image_token_index
|
||||||
|
# Let's pray we have enabled enough slots !
|
||||||
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
pixel_values = pixel_values.view(
|
||||||
|
batch_size * num_images, *pixel_values.shape[2:]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove padding images - padding images are full 0.
|
||||||
|
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||||
|
real_images_inds = (pixel_values == 0.0).sum(
|
||||||
|
dim=(-1, -2, -3)
|
||||||
|
) != nb_values_per_image
|
||||||
|
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||||
|
|
||||||
|
# Handle the vision attention mask
|
||||||
|
if pixel_attention_mask is None:
|
||||||
|
pixel_attention_mask = torch.ones(
|
||||||
|
size=(
|
||||||
|
pixel_values.size(0),
|
||||||
|
pixel_values.size(2),
|
||||||
|
pixel_values.size(3),
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Remove padding images from the mask/pP p
|
||||||
|
pixel_attention_mask = pixel_attention_mask.view(
|
||||||
|
batch_size * num_images, *pixel_attention_mask.shape[2:]
|
||||||
|
)
|
||||||
|
pixel_attention_mask = pixel_attention_mask[
|
||||||
|
real_images_inds
|
||||||
|
].contiguous()
|
||||||
|
|
||||||
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
patches_subgrid = pixel_attention_mask.unfold(
|
||||||
|
dimension=1, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patches_subgrid = patches_subgrid.unfold(
|
||||||
|
dimension=2, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||||
|
|
||||||
|
# Get sequence from the vision encoder
|
||||||
|
image_hidden_states = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modality projection & resampling
|
||||||
|
image_hidden_states = self.connector(
|
||||||
|
image_hidden_states,
|
||||||
|
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||||
|
)
|
||||||
|
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||||
|
# that simply don't exist
|
||||||
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
|
input_ids, inputs_embeds, image_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
max_s=max_s,
|
||||||
|
true_max_s=max_s,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@ -23,6 +23,10 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.image_processing_utils import select_best_resolution
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def load_vision_model(prefix, config, weights):
|
|
||||||
if config.model_type == "clip_vision_model":
|
|
||||||
from text_generation_server.models.custom_modeling.clip import (
|
|
||||||
CLIPVisionTransformer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return CLIPVisionTransformer(
|
|
||||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_text_model(prefix, config, weights):
|
|
||||||
if config.model_type == "llama":
|
|
||||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
|
||||||
FlashLlamaForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
|
||||||
elif config.model_type == "mistral":
|
|
||||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
|
||||||
FlashMistralForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return FlashMistralForCausalLM(prefix, config, weights)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextForConditionalGeneration(nn.Module):
|
class LlavaNextForConditionalGeneration(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
28
server/text_generation_server/models/custom_modeling/vlm.py
Normal file
28
server/text_generation_server/models/custom_modeling/vlm.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
def load_text_model(prefix, config, weights, name=None):
|
||||||
|
if config.model_type == "llama":
|
||||||
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
|
FlashLlamaForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||||
|
elif config.model_type == "mistral":
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
|
FlashMistralForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashMistralForCausalLM(prefix, config, weights, name=name)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vision_model(prefix, config, weights):
|
||||||
|
if config.model_type == "clip_vision_model":
|
||||||
|
from text_generation_server.models.custom_modeling.clip import (
|
||||||
|
CLIPVisionTransformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CLIPVisionTransformer(
|
||||||
|
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
@ -1,31 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
|
||||||
AutoConfig,
|
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||||
from text_generation_server.models.custom_modeling.idefics2_config import IdeficsConfig
|
Idefics2ForConditionalGeneration,
|
||||||
from text_generation_server.models.custom_modeling.idefics_processing import (
|
|
||||||
IdeficsProcessor,
|
|
||||||
)
|
|
||||||
from transformers import LlamaTokenizerFast
|
|
||||||
from text_generation_server.models.custom_modeling.idefics2_modeling import (
|
|
||||||
Idefics2ForVisionText2Text,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
|
|
||||||
class IDEFICS2Sharded(IdeficsCausalLM):
|
|
||||||
|
class Idefics2(VlmCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -35,59 +22,25 @@ class IDEFICS2Sharded(IdeficsCausalLM):
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
if torch.cuda.is_available():
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
# 9b seems to work correctly enough in float16, but 80b seems
|
|
||||||
# to be really saturating for f16.
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
self.device, self.dtype = device, dtype
|
|
||||||
|
|
||||||
config = IdeficsConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
super().__init__(
|
||||||
config.use_medusa = use_medusa
|
model_cls=Idefics2ForConditionalGeneration,
|
||||||
config.vision_config.quantize = quantize
|
model_id=model_id,
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
revision=revision,
|
||||||
padding_side="left",
|
quantize=quantize,
|
||||||
truncation_side="left",
|
use_medusa=use_medusa,
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
self.processor = IdeficsProcessor.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
process_group=self.process_group,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = IdeficsForVisionText2Text(config, weights)
|
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||||
|
return (
|
||||||
torch.distributed.barrier(group=self.process_group)
|
len(model.text_model.model.layers),
|
||||||
super(IdeficsCausalLM, self).__init__(
|
model.text_model.model.num_key_value_heads,
|
||||||
model=model,
|
model.text_model.model.head_size,
|
||||||
tokenizer=tokenizer,
|
|
||||||
requires_padding=True,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def max_past(self) -> Optional[int]:
|
||||||
|
return getattr(self.model.text_model, "max_past", None)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||||
|
return (
|
||||||
|
len(model.language_model.model.layers),
|
||||||
|
model.language_model.model.num_key_value_heads,
|
||||||
|
model.language_model.model.head_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_past(self) -> Optional[int]:
|
||||||
|
return getattr(self.model.language_model, "max_past", None)
|
||||||
|
@ -147,8 +147,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||||||
"Cannot process input image not starting with data:"
|
"Cannot process input image not starting with data:"
|
||||||
)
|
)
|
||||||
image_input = processor.image_processor(image, return_tensors="pt")
|
image_input = processor.image_processor(image, return_tensors="pt")
|
||||||
height, width = image_input["image_sizes"][0]
|
# import ipdb;ipdb.set_trace()
|
||||||
num_features = get_number_of_features(height, width, config)
|
# height, width = image_input["image_sizes"][0]
|
||||||
|
# num_features = get_number_of_features(height, width, config)
|
||||||
|
num_features = 1
|
||||||
full_text += "<image>" * num_features
|
full_text += "<image>" * num_features
|
||||||
image_inputs.append(image_input)
|
image_inputs.append(image_input)
|
||||||
else:
|
else:
|
||||||
@ -165,7 +167,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||||||
"pixel_values": torch.cat(
|
"pixel_values": torch.cat(
|
||||||
[img["pixel_values"] for img in image_inputs], dim=0
|
[img["pixel_values"] for img in image_inputs], dim=0
|
||||||
),
|
),
|
||||||
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
|
"pixel_attention_mask": torch.cat(
|
||||||
|
[img["pixel_attention_mask"] for img in image_inputs], dim=0
|
||||||
|
),
|
||||||
|
# "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
image_inputs = None
|
image_inputs = None
|
||||||
@ -187,10 +192,14 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||||
if image_inputs is not None:
|
if image_inputs is not None:
|
||||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
# batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||||
else:
|
else:
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.image_sizes = None
|
batch.pixel_attention_mask = None
|
||||||
|
# batch.image_sizes = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@ -199,16 +208,6 @@ class VlmCausalLM(BaseFlashMistral):
|
|||||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||||
return VlmCausalLMBatch
|
return VlmCausalLMBatch
|
||||||
|
|
||||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
|
||||||
return (
|
|
||||||
len(model.language_model.model.layers),
|
|
||||||
model.language_model.model.num_key_value_heads,
|
|
||||||
model.language_model.model.head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def max_past(self) -> Optional[int]:
|
|
||||||
return getattr(self.model.language_model, "max_past", None)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: VlmCausalLMBatch
|
self, batch: VlmCausalLMBatch
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -294,14 +293,17 @@ class VlmCausalLM(BaseFlashMistral):
|
|||||||
prefill_cache_indices=batch.prefill_cache_indices,
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
pixel_values=batch.pixel_values,
|
pixel_values=batch.pixel_values,
|
||||||
image_sizes=batch.image_sizes,
|
pixel_attention_mask=batch.pixel_attention_mask,
|
||||||
|
# image_sizes=batch.image_sizes,
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
if batch.image_sizes is not None:
|
if batch.pixel_attention_mask is not None:
|
||||||
batch.image_sizes = None
|
batch.pixel_attention_mask = None
|
||||||
|
# if batch.image_sizes is not None:
|
||||||
|
# batch.image_sizes = None
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
# Copy inputs to the static inputs of the cuda graph
|
# Copy inputs to the static inputs of the cuda graph
|
||||||
|
Loading…
Reference in New Issue
Block a user