Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

fix: clean up idefics 3 and improve prefix handling

parent 4c8f5cdc35
commit 765ca78014
@@ -632,20 +632,24 @@ class FlashLlamaModel(torch.nn.Module):
 
 
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights):
+    def __init__(self, prefix: str, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()
-        base_model = "" if prefix.endswith("text_model") else ".model"
-
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
                 prefix=(
-                    "model.embed_tokens"
+                    f"{name}.embed_tokens"
                     if not prefix
-                    else f"{prefix}{base_model}.embed_tokens"
+                    else f"{prefix}.{name}.embed_tokens"
                 ),
                 weights=weights,
             )
-        self.model = FlashLlamaModel(prefix, config, weights)
+        self.model = FlashLlamaModel(
+            prefix=name if not prefix else f"{prefix}.{name}",
+            config=config,
+            weights=weights,
+        )
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -656,18 +660,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
-        if not prefix:
-            head_prefix = suffix
-        elif prefix.endswith("text_model"):
-            head_prefix = suffix
-        else:
-            head_prefix = f"{prefix}.{suffix}"
+        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
 
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix=head_prefix,
-                weights=weights,
+                prefix,
+                weights,
             )
 
         # Used in Granite
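Note on the two hunks above: below is a minimal, hypothetical sketch of the weight-name resolution that the new `name` argument enables, assuming the standalone and Idefics3-embedded loading paths shown in this diff. The `resolve_prefixes` helper is not part of the change; it only illustrates the resulting tensor names.

    # Hypothetical helper mirroring the prefix logic above (illustration only).
    def resolve_prefixes(prefix: str, name: str = "model"):
        embed = f"{name}.embed_tokens" if not prefix else f"{prefix}.{name}.embed_tokens"
        model = name if not prefix else f"{prefix}.{name}"
        lm_head = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
        return embed, model, lm_head

    # Standalone Llama checkpoint:
    #   ("model.embed_tokens", "model", "lm_head")
    print(resolve_prefixes(""))

    # Llama used as the Idefics3 text tower (prefix="model", name="text_model"):
    #   ("model.text_model.embed_tokens", "model.text_model", "lm_head")
    print(resolve_prefixes("model", name="text_model"))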
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics2 model."""
+""" PyTorch Idefics3 model."""
 
 from typing import List, Optional, Tuple
 
@@ -50,7 +50,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-class Idefics2VisionEmbeddings(nn.Module):
+class Idefics3VisionEmbeddings(nn.Module):
     """
     This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
     resolution.
@@ -131,7 +131,7 @@ class Idefics2VisionEmbeddings(nn.Module):
         return embeddings
 
 
-class Idefics2VisionAttention(nn.Module):
+class Idefics3VisionAttention(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
@@ -229,7 +229,7 @@ class Idefics2VisionAttention(nn.Module):
         return attn_output
 
 
-class Idefics2VisionMLP(nn.Module):
+class Idefics3VisionMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
@@ -248,11 +248,11 @@ class Idefics2VisionMLP(nn.Module):
         return hidden_states
 
 
-class Idefics2EncoderLayer(nn.Module):
+class Idefics3EncoderLayer(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(
+        self.self_attn = Idefics3VisionAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
         self.layer_norm1 = nn.LayerNorm.load(
@@ -261,7 +261,7 @@ class Idefics2EncoderLayer(nn.Module):
         self.layer_norm2 = nn.LayerNorm.load(
             prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
         )
-        self.mlp = Idefics2VisionMLP(
+        self.mlp = Idefics3VisionMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights
         )
 
@@ -288,13 +288,13 @@ class Idefics2EncoderLayer(nn.Module):
         return hidden_states
 
 
-class Idefics2Encoder(nn.Module):
+class Idefics3Encoder(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList(
             [
-                Idefics2EncoderLayer(
+                Idefics3EncoderLayer(
                     prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                 )
                 for i in range(config.num_hidden_layers)
@@ -316,14 +316,14 @@ class Idefics2Encoder(nn.Module):
         return hidden_states
 
 
-class Idefics2VisionTransformer(nn.Module):
+class Idefics3VisionTransformer(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
-        self.embeddings = Idefics2VisionEmbeddings(
+        self.embeddings = Idefics3VisionEmbeddings(
             prefix=f"{prefix}.embeddings", config=config, weights=weights
         )
-        self.encoder = Idefics2Encoder(
+        self.encoder = Idefics3Encoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
         self.post_layernorm = nn.LayerNorm.load(
@@ -377,317 +377,26 @@ class Idefics2VisionTransformer(nn.Module):
         return last_hidden_state
 
 
-class Idefics2MLP(nn.Module):
+class Idefics3SimpleMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        act = config.text_config.hidden_act
-        self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        )
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-            weights=weights,
-            dim=0,
-            bias=False,
-        )
-        self.down_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.down_proj",
-            weights=weights,
-            bias=False,
-        )
+        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+        output_size = config.text_config.hidden_size
+        proj = nn.Parameter(
+            weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
+            requires_grad=False,
+        ).to(weights.dtype)
+        self.proj = nn.Linear(input_size, output_size, bias=False)
+        self.proj.weight = proj
 
-    def forward(self, hidden_states):
-        start_shape = hidden_states.shape[:-1]
-        gate_up_states = self.gate_up_proj(hidden_states)
-        intermediate_size = gate_up_states.shape[-1] // 2
-        gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
-        return self.down_proj(
-            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
-        ).view(*start_shape, -1)
+    def forward(self, x):
+        return self.proj(x)
-
-
-class Idefics2RMSNorm(nn.Module):
-    def __init__(self, prefix, weights, eps):
-        """
-        Idefics2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(
-            weights.get_tensor(f"{prefix}.weight"), requires_grad=False
-        )
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
-class Idefics2PerceiverAttention(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-
-        self.layer_idx = None
-        self.hidden_size = config.text_config.hidden_size
-        self.num_heads = config.perceiver_config.resampler_n_heads
-        self.head_size = config.perceiver_config.resampler_head_dim
-        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.attention_dropout = config.perceiver_config.attention_dropout
-        self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = (
-            self.num_key_value_heads // weights.process_group.size()
-        )
-
-        self.q_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.q_proj",
-            weights=weights,
-            bias=False,
-        )
-        self.kv = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
-        )
-        self.o_proj = TensorParallelRowLinear.load(
-            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
-        )
-
-        self.is_causal = False
-
-    def forward(
-        self,
-        latents: torch.Tensor,
-        context: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = latents.size()
-        kv_seq_len = q_len + context.size()[1]
-
-        hidden_states = torch.concat([context, latents], dim=-2)
-        query_states = self.q_proj(latents)
-        kv = self.kv(hidden_states)
-        key_states, value_states = kv.split(
-            [
-                self.head_size * self.num_key_value_heads,
-                self.head_size * self.num_key_value_heads,
-            ],
-            dim=2,
-        )
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_size
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
-        ).transpose(1, 2)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(
-            query_states, key_states.transpose(2, 3)
-        ) / math.sqrt(self.head_size)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output
-
-
-class Idefics2PerceiverLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        self.hidden_size = config.text_config.hidden_size
-        self.n_latents = config.perceiver_config.resampler_n_latents
-        self.depth = config.perceiver_config.resampler_depth
-        self.rms_norm_eps = config.text_config.rms_norm_eps
-
-        self.input_latents_norm = Idefics2RMSNorm(
-            prefix=f"{prefix}.input_latents_norm",
-            weights=weights,
-            eps=self.rms_norm_eps,
-        )
-        self.input_context_norm = Idefics2RMSNorm(
-            prefix=f"{prefix}.input_context_norm",
-            weights=weights,
-            eps=self.rms_norm_eps,
-        )
-        self.self_attn = Idefics2PerceiverAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
-        )
-        self.post_attention_layernorm = Idefics2RMSNorm(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=self.rms_norm_eps,
-        )
-        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-
-    def forward(
-        self,
-        latents: torch.Tensor,
-        context: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        """
-        Args:
-            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, sequence_length)` where padding elements are indicated by 0.
-        """
-        residual = latents
-
-        latents = self.input_latents_norm(latents)
-        context = self.input_context_norm(context)
-
-        latents = self.self_attn(
-            latents=latents,
-            context=context,
-            attention_mask=attention_mask,
-        )
-        latents = residual + latents
-        residual = latents
-
-        latents = self.post_attention_layernorm(latents)
-        latents = self.mlp(latents)
-        latents = residual + latents
-
-        return latents
-
-
-class Idefics2PerceiverResampler(nn.Module):
-    def __init__(self, prefix, config, weights) -> None:
-        super().__init__()
-        self.hidden_size = config.text_config.hidden_size
-        self.hidden_act = config.perceiver_config.hidden_act
-        self.n_latents = config.perceiver_config.resampler_n_latents
-        self.depth = config.perceiver_config.resampler_depth
-        self.rms_norm_eps = config.text_config.rms_norm_eps
-
-        # Create Latents for Perceiver
-        self.latents = weights.get_tensor(f"{prefix}.latents")
-
-        # Create Transformer Blocks
-        self.layers = nn.ModuleList(
-            [
-                Idefics2PerceiverLayer(
-                    prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
-                )
-                for idx in range(self.depth)
-            ]
-        )
-        self.norm = Idefics2RMSNorm(
-            prefix=f"{prefix}.norm",
-            weights=weights,
-            eps=config.text_config.rms_norm_eps,
-        )
-
-    def forward(
-        self,
-        context: torch.Tensor,
-        attention_mask,
-    ) -> torch.Tensor:
-        # seq embed -> bsz seq embed
-        latents = self.latents.unsqueeze(0).expand(
-            (context.shape[0], *self.latents.size())
-        )
-
-        latent_attention_mask = torch.ones(
-            (attention_mask.size(0), latents.size(1)),
-            dtype=attention_mask.dtype,
-            device=attention_mask.device,
-        )
-        attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
-        attention_mask = _prepare_4d_attention_mask(
-            attention_mask, latents.dtype, tgt_len=self.n_latents
-        )
-
-        compressed_context = latents
-        for perceiver_layer in self.layers:
-            compressed_context = perceiver_layer(
-                compressed_context,
-                context,
-                attention_mask=attention_mask,
-            )
-        compressed_context = self.norm(compressed_context)
-
-        return compressed_context
-
-
-class Idefics2Connector(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        self.modality_projection = Idefics2MLP(
-            prefix=f"{prefix}.modality_projection", config=config, weights=weights
-        )
-        self.perceiver_resampler = Idefics2PerceiverResampler(
-            prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
-        )
-
-    def forward(self, image_hidden_states, attention_mask):
-        image_hidden_states = self.modality_projection(image_hidden_states)
-        image_hidden_states = self.perceiver_resampler(
-            context=image_hidden_states, attention_mask=attention_mask
-        )
-        return image_hidden_states
 
 
 class Idefics3Connector(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.modality_projection = TensorParallelRowLinear.load(
-            prefix=f"{prefix}.modality_projection.proj",
-            config=config,
-            weights=weights,
-            bias=False,
-        )
+        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
         self.scale_factor = config.scale_factor
 
     def pixel_shuffle(self, x, scale_factor=2):
@@ -706,8 +415,7 @@ class Idefics3Connector(nn.Module):
         x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
         return x
 
-    def forward(self, image_hidden_states, attention_mask):
-        print(image_hidden_states.device, self.modality_projection.linear.weight.device)
+    def forward(self, image_hidden_states):
         image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
         image_hidden_states = self.modality_projection(image_hidden_states)
         return image_hidden_states
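A rough, self-contained sketch of the shape flow through the rewritten connector, under assumed dimensions (not values read from a real config): pixel_shuffle trades sequence length for channel width, and Idefics3SimpleMLP is a single bias-free projection into the text hidden size. The reshape below only reproduces the output shape, not the exact spatial interleaving that pixel_shuffle performs.

    import torch

    bsz, seq, vision_hidden, scale = 1, 4096, 1152, 2  # assumed sizes
    x = torch.randn(bsz, seq, vision_hidden)

    # pixel_shuffle output shape: (bsz, seq / scale**2, vision_hidden * scale**2)
    shuffled = x.reshape(bsz, seq // scale**2, vision_hidden * scale**2)

    # Idefics3SimpleMLP: Linear(vision_hidden * scale**2 -> text_hidden), bias=False
    text_hidden = 4096
    proj = torch.nn.Linear(vision_hidden * scale**2, text_hidden, bias=False)
    print(proj(shuffled).shape)  # torch.Size([1, 1024, 4096])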
@@ -726,7 +434,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
         vision_config = config.vision_config
         self.text_model = load_text_model(
-            prefix=f"{prefix}.model.text_model" if prefix else "model.text_model",
+            prefix="model" if not prefix else f"{prefix}.model",
             config=config.text_config,
             weights=weights,
             name="text_model",
@@ -735,7 +443,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
         # The vision and connector models are not quantized.
         with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
-            self.vision_model = Idefics2VisionTransformer(
+            self.vision_model = Idefics3VisionTransformer(
                 prefix=(
                     f"{prefix}.model.vision_model" if prefix else "model.vision_model"
                 ),
@@ -810,7 +518,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
                     dim=(-1, -2, -3)
                 ) != nb_values_per_image
                 pixel_values = pixel_values[real_images_inds].contiguous()
-
                 # Handle the vision attention mask
                 if pixel_attention_mask is None:
                     pixel_attention_mask = torch.ones(
@@ -850,7 +557,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
                 # Modality projection & resampling
                 image_hidden_states = self.connector(
                     image_hidden_states,
-                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
                 )
 
                 all_states.append(image_hidden_states)
@@ -877,164 +583,3 @@ class Idefics3ForConditionalGeneration(nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.text_model.lm_head(hidden_states)
         return logits, speculative_logits
-
-
-class Idefics2ForConditionalGeneration(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        config.vision_config.quantize = None
-        config.vision_config.speculator = config.speculator
-        config.text_config.quantize = config.quantize
-        config.text_config.speculator = config.speculator
-
-        vision_config = config.vision_config
-        self.text_model = load_text_model(
-            prefix="model" if not prefix else f"{prefix}.model",
-            config=config.text_config,
-            weights=weights,
-            name="text_model",
-        )
-        self.dtype = weights.dtype
-
-        # The vision and connector models are not quantized.
-        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
-            self.vision_model = Idefics2VisionTransformer(
-                prefix=(
-                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
-                ),
-                config=vision_config,
-                weights=weights,
-            )
-
-            config.quantize = None
-            self.connector = Idefics2Connector(
-                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
-                config=config,
-                weights=weights,
-            )
-
-        self.config = config
-        self.image_seq_len = config.perceiver_config.resampler_n_latents
-        self.image_token_id = config.image_token_id
-        self.pad_token_id = (
-            config.pad_token_id if config.pad_token_id is not None else -1
-        )
-
-    def _merge_input_ids_with_image_features(
-        self,
-        input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
-        # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        return inputs_embeds
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        max_s: int,
-        prefill_cache_indices: Optional[torch.Tensor],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        # Unused here
-        image_sizes: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-    ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            all_states = []
-            all_pixel_values = pixel_values
-            all_pixel_mask = pixel_attention_mask
-            for i in range(batch_size):
-                pixel_values = all_pixel_values.to(
-                    dtype=self.dtype
-                ) # fp16 compatibility
-                pixel_values = pixel_values[i : i + 1]
-                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
-
-                # Remove padding images - padding images are full 0.
-                nb_values_per_image = pixel_values.shape[1:].numel()
-                real_images_inds = (pixel_values == 0.0).sum(
-                    dim=(-1, -2, -3)
-                ) != nb_values_per_image
-                pixel_values = pixel_values[real_images_inds].contiguous()
-
-                # Handle the vision attention mask
-                if pixel_attention_mask is None:
-                    pixel_attention_mask = torch.ones(
-                        size=(
-                            pixel_values.size(0),
-                            pixel_values.size(2),
-                            pixel_values.size(3),
-                        ),
-                        dtype=torch.bool,
-                        device=pixel_values.device,
-                    )
-                else:
-                    # Remove padding images from the mask/pP p
-                    pixel_attention_mask = all_pixel_mask[i : i + 1]
-                    pixel_attention_mask = pixel_attention_mask.view(
-                        1 * num_images, *pixel_attention_mask.shape[2:]
-                    )
-                    pixel_attention_mask = pixel_attention_mask[
-                        real_images_inds
-                    ].contiguous()
-
-                patch_size = self.config.vision_config.patch_size
-                patches_subgrid = pixel_attention_mask.unfold(
-                    dimension=1, size=patch_size, step=patch_size
-                )
-                patches_subgrid = patches_subgrid.unfold(
-                    dimension=2, size=patch_size, step=patch_size
-                )
-                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-                # Get sequence from the vision encoder
-                image_hidden_states = self.vision_model(
-                    pixel_values=pixel_values,
-                    patch_attention_mask=patch_attention_mask,
-                )
-
-                # Modality projection & resampling
-                image_hidden_states = self.connector(
-                    image_hidden_states,
-                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
-                )
-                all_states.append(image_hidden_states)
-            image_hidden_states = torch.stack(all_states, dim=0)
-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_hidden_states
-            )
-
-        hidden_states = self.text_model.model(
-            inputs_embeds=inputs_embeds,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
-            prefill_cache_indices=None,
-            adapter_data=adapter_data,
-        )
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.text_model.lm_head(hidden_states)
-        return logits, speculative_logits
@@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashLlamaForCausalLM,
         )
 
-        return FlashLlamaForCausalLM(prefix, config, weights)
+        return FlashLlamaForCausalLM(prefix, config, weights, name=name)
     elif config.model_type == "mistral":
         from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
             FlashMistralForCausalLM,
@@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLM,
 )
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -29,25 +30,32 @@ IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
 
 
 def get_image_prompt_string(
-    rows=0,
-    cols=0,
-    seq_len=1,
-    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
-    img_token=IDEFICS3_IMAGE_TOKEN,
-    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+    *,
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
 ):
-    tokens = img_token * seq_len
-    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
 
-    if rows == 0 or cols == 0:
-        return end_token
-
-    grid = "\n".join(
-        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
-        for i in range(rows)
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
     )
-
-    return f"{grid}\n\n{end_token}"
+    return text_split_images
 
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
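For reference, a small usage sketch of the rewritten get_image_prompt_string with image_seq_len=1 and a 2x2 grid so the output stays short; real values come from the image processor and are much larger. The token constants are the ones defined in this file.

    # Illustrative call only.
    s = get_image_prompt_string(
        image_seq_len=1,
        image_rows=2,
        image_cols=2,
        fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
        image_token=IDEFICS3_IMAGE_TOKEN,
        global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
    )
    # Each grid row is "<fake><row_i_col_j><image>" per column followed by a newline;
    # the global block "<fake><global-img><image><fake>" is appended at the end.
    print(s)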
@@ -89,18 +97,17 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             / (config.scale_factor**2)
         )
         image_str = get_image_prompt_string(
-            rows=n_rows,
-            cols=n_cols,
-            seq_len=image_seq_len,
-            fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
-            img_token=IDEFICS3_IMAGE_TOKEN,
-            global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+            image_seq_len=image_seq_len,
+            image_rows=n_rows,
+            image_cols=n_cols,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
         )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
-        from loguru import logger
 
         log_master(
             logger.info,
@@ -238,9 +245,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         if images:
             kwargs = {}
-            match processor.image_processor_class:
-                case "Idefics3ImageProcessor":
-                    kwargs["return_row_col_info"] = True
+            if (
+                hasattr(processor, "image_processor_class")
+                and processor.image_processor_class == "Idefics3ImageProcessor"
+            ):
+                kwargs["return_row_col_info"] = True
 
             image_inputs = processor.image_processor(
                 images, return_tensors="pt", **kwargs
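A hedged sketch of what the return_row_col_info switch feeds, assuming the Idefics3 processor from transformers; the checkpoint id and dummy image are placeholders, and the exact keys of image_inputs are an assumption based on how image_text_replacement reads the row/col counts above.

    # Sketch only: with Idefics3's image processor, return_row_col_info=True makes the
    # call also return per-image row/col grid counts used to build <row_i_col_j> tokens.
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")  # placeholder id
    images = [Image.new("RGB", (512, 512))]

    kwargs = {}
    if getattr(processor, "image_processor_class", None) == "Idefics3ImageProcessor":
        kwargs["return_row_col_info"] = True

    image_inputs = processor.image_processor(images, return_tensors="pt", **kwargs)
    # Expected (assumption): image_inputs carries "rows" and "cols" alongside "pixel_values".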