diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py
index d381d4c6..7e740e5f 100644
--- a/backends/gaudi/server/text_generation_server/layers/rotary.py
+++ b/backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -36,7 +36,9 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
- self.max_position_embeddings = max_position_embeddings
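+        # Precompute the cos/sin tables for the full max_position_embeddings range once at init time.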
+ self._update_cos_sin_cache(
+ torch.float32, inv_freq.device, max_position_embeddings
+ )
def forward(
self,
@@ -268,9 +270,7 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
- self._update_cos_sin_cache(
- torch.float32, position_ids.device, seqlen=self.max_position_embeddings
- )
+
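+        # The cos/sin tables are built in __init__, so only index into the cache here.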
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
@@ -298,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@@ -351,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
@@ -592,9 +598,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
- self._update_cos_sin_cache(
- torch.float32, position_ids.device, seqlen=self.max_position_embeddings
- )
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 7a32a85c..367c26c9 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -160,18 +160,14 @@ class FlashCohereAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
- self.rotary_emb = CohereRotary.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -325,11 +321,14 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module):
- def __init__(self, prefix: str, layer_id, config, weights):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -385,6 +384,12 @@ class FlashCohereModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
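+        # Build the rotary embedding once on the target device and share it across all attention layers.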
+ rotary_emb = CohereRotary.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
FlashCohereLayer(
@@ -392,6 +397,7 @@ class FlashCohereModel(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 42af7798..c097f71e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -263,6 +263,7 @@ class DbrxAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.clip_qkv = config.attn_config.clip_qkv
@@ -270,12 +271,7 @@ class DbrxAttention(torch.nn.Module):
self.hidden_size = config.d_model
self.head_size = self.hidden_size // self.num_heads
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.attn_config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -370,13 +366,17 @@ class DbrxNormAttentionNorm(nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.norm_1 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
)
self.self_attn = DbrxAttention(
- prefix=f"{prefix}.attn", config=config, weights=weights
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.norm_2 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2",
@@ -601,12 +601,15 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module):
- def __init__(self, prefix: str, layer_id, config, weights):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm(
- prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
+ prefix=f"{prefix}.norm_attn_norm",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
@@ -649,6 +652,12 @@ class DbrxModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.wte", weights=weights
)
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.d_model // config.n_heads,
+ base=config.attn_config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
@@ -657,6 +666,7 @@ class DbrxModel(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.n_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 8e9002a2..08b7d99d 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -156,6 +156,7 @@ class DeepseekV2Attention(torch.nn.Module):
prefix: str,
config,
weights: Weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -167,13 +168,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.qk_rope_head_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
@@ -459,7 +454,7 @@ class DeepseekV2MoE(nn.Module):
class DeepseekV2Layer(nn.Module):
- def __init__(self, prefix, layer_id, config, weights):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@@ -467,6 +462,7 @@ class DeepseekV2Layer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
if (
@@ -541,6 +537,12 @@ class DeepseekV2Model(torch.nn.Module):
prefix=f"{prefix}.embed_tokens", weights=weights
)
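+        # Rotary embeddings cover only the qk_rope_head_dim slice of each attention head.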
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
DeepseekV2Layer(
@@ -548,6 +550,7 @@ class DeepseekV2Model(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
index 8e058093..3a6a974a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -169,6 +169,7 @@ class DeepseekV3Attention(torch.nn.Module):
prefix: str,
config,
weights: Weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -180,13 +181,7 @@ class DeepseekV3Attention(torch.nn.Module):
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.qk_rope_head_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
@@ -535,7 +530,7 @@ class DeepseekV3MoE(nn.Module):
class DeepseekV3Layer(nn.Module):
- def __init__(self, prefix, layer_id, config, weights):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@@ -543,6 +538,7 @@ class DeepseekV3Layer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
if (
@@ -616,6 +612,12 @@ class DeepseekV3Model(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
@@ -624,6 +626,7 @@ class DeepseekV3Model(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index a1a20999..74d9397e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -166,7 +166,14 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemma2Attention(torch.nn.Module):
def __init__(
- self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -176,13 +183,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.window_size = config.sliding_window
else:
self.window_size = -1
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
# self.softmax_scale = self.head_size**-0.5
self.softmax_scale = config.query_pre_attn_scalar**-0.5
@@ -354,7 +355,14 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module):
def __init__(
- self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ rotary_emb,
):
super().__init__()
self.self_attn = FlashGemma2Attention(
@@ -364,6 +372,7 @@ class FlashGemma2Layer(nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
+ rotary_emb=rotary_emb,
)
self.mlp = Gemma2MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@@ -435,6 +444,13 @@ class FlashGemma2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
FlashGemma2Layer(
@@ -444,6 +460,7 @@ class FlashGemma2Model(torch.nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=layer_id % 2 == 0,
+ rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index 92f059bc..7b789d30 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -119,7 +119,15 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemma3Attention(torch.nn.Module):
def __init__(
- self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ local_rotary_emb,
+ global_rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -130,20 +138,10 @@ class FlashGemma3Attention(torch.nn.Module):
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=config.head_dim,
- base=config.rope_local_base_freq,
- device=weights.device,
- )
+ self.rotary_emb = local_rotary_emb
else:
self.window_size = -1
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=config.head_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = global_rotary_emb
self.softmax_scale = (
config.query_pre_attn_scalar**-0.5
@@ -336,7 +334,15 @@ class Gemma3MLP(nn.Module):
class FlashGemma3Layer(nn.Module):
def __init__(
- self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ local_rotary_emb,
+ global_rotary_emb,
):
super().__init__()
self.self_attn = FlashGemma3Attention(
@@ -346,6 +352,8 @@ class FlashGemma3Layer(nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
+ local_rotary_emb=local_rotary_emb,
+ global_rotary_emb=global_rotary_emb,
)
self.mlp = Gemma3MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@@ -417,6 +425,18 @@ class FlashGemma3Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
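+        # Sliding-window layers rotate with the local rope base; full-attention layers use the global rope_theta.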
+ local_rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_local_base_freq,
+ device=weights.device,
+ )
+ global_rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
@@ -427,6 +447,8 @@ class FlashGemma3Model(torch.nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
+ local_rotary_emb=local_rotary_emb,
+ global_rotary_emb=global_rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 7a2ec22e..5d6dc67c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -163,19 +163,12 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module):
- def __init__(self, prefix: str, config, weights, causal: bool):
+ def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
-
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
@@ -300,10 +293,14 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
- def __init__(self, prefix: str, config, weights, causal: bool):
+ def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
super().__init__()
self.self_attn = FlashGemmaAttention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ causal=causal,
+ rotary_emb=rotary_emb,
)
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -359,6 +356,13 @@ class FlashGemmaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
FlashGemmaLayer(
@@ -366,6 +370,7 @@ class FlashGemmaModel(torch.nn.Module):
config=config,
weights=weights,
causal=causal,
+ rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index 679380a1..1e7a867c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -110,6 +110,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -143,13 +144,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
-
- self.rotary_emb = GPTJRotary.static(
- config=config,
- dim=self.rotary_dim,
- base=10000,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
def forward(
self,
@@ -244,10 +239,13 @@ class GPTJMLP(nn.Module):
class FlashGPTJLayer(nn.Module):
- def __init__(self, prefix: str, config, weights):
+ def __init__(self, prefix: str, config, weights, rotary_emb):
super().__init__()
self.self_attn = FlashGPTJAttention(
- prefix=f"{prefix}.attn", config=config, weights=weights
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -291,6 +289,12 @@ class FlashGPTJModel(torch.nn.Module):
self.config = config
self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
+ rotary_emb = GPTJRotary.static(
+ config=config,
+ dim=config.rotary_dim,
+ base=10000,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
FlashGPTJLayer(
@@ -299,6 +303,7 @@ class FlashGPTJModel(torch.nn.Module):
),
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 3b30f2e0..1db8ad10 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -303,7 +303,7 @@ class Llama4TextAttention(FlashLlamaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, prefix, config, weights, layer_idx):
- super().__init__(layer_idx, prefix, config, weights)
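+        # Passing None keeps the previous behavior: FlashLlamaAttention skipped the shared rotary for llama4_text.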
+ super().__init__(layer_idx, prefix, config, weights, None)
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 70fcc824..fbfcd39c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -133,6 +133,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -145,13 +146,7 @@ class FlashLlamaAttention(torch.nn.Module):
config, "num_key_value_heads", config.num_attention_heads
)
- if config.model_type != "llama4_text":
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
@@ -376,7 +371,7 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
- def __init__(self, index, prefix, config, weights):
+ def __init__(self, index, prefix, config, weights, rotary_emb):
super().__init__()
with no_fp8(weights):
@@ -385,6 +380,7 @@ class FlashLlamaLayer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
if config.model_type == "phimoe":
@@ -480,6 +476,13 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(
@@ -487,6 +490,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=f"{prefix}.layers.0",
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
)
@@ -512,6 +516,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=(f"{prefix}.layers.{layer_id}"),
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
)
@@ -523,6 +528,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=(f"{prefix}.layers.{last_layer_id}"),
config=config,
weights=weights,
+ rotary_emb=rotary_emb,
)
)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 008df32d..f7aed118 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -104,7 +104,7 @@ class MistralConfig(PretrainedConfig):
class MistralAttention(torch.nn.Module):
- def __init__(self, prefix: str, config, weights, layer_id):
+ def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
@@ -117,12 +117,7 @@ class MistralAttention(torch.nn.Module):
else:
self.head_size = self.hidden_size // self.num_heads
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -300,13 +295,14 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module):
- def __init__(self, prefix: str, config, weights, layer_id):
+ def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
+ rotary_emb=rotary_emb,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@@ -366,6 +362,19 @@ class MistralModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
+
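+        # Some Mistral configs define head_dim explicitly; otherwise derive it from hidden_size.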
+ if getattr(config, "head_dim", None) is not None:
+ head_dim = config.head_dim
+ else:
+ head_dim = config.hidden_size // config.num_attention_heads
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
MistralLayer(
@@ -373,6 +382,7 @@ class MistralModel(torch.nn.Module):
config=config,
weights=weights,
layer_id=layer_id,
+ rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 4993b444..8c682c7f 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -188,6 +188,7 @@ class MixtralAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.max_past = (
@@ -196,13 +197,7 @@ class MixtralAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -345,12 +340,15 @@ class MixtralMoE(nn.Module):
class MixtralLayer(nn.Module):
- def __init__(self, prefix: str, layer_id, config, weights):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = MixtralAttention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
moe_layer_cls = (
@@ -416,6 +414,12 @@ class MixtralModel(torch.nn.Module):
weights=weights,
)
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
MixtralLayer(
@@ -423,6 +427,7 @@ class MixtralModel(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 6e1050b6..8ee1dfa2 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -99,7 +99,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
class FlashNeoxAttention(torch.nn.Module):
- def __init__(self, config, prefix, weights):
+ def __init__(self, config, prefix, weights, rotary_emb):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
@@ -116,14 +116,7 @@ class FlashNeoxAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.rotary_dim,
- base=config.rotary_emb_base,
- device=weights.device,
- )
-
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size ** (-0.5)
self.query_key_value = load_qkv(
@@ -231,7 +224,7 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module):
- def __init__(self, layer_id, config, weights):
+ def __init__(self, layer_id, config, weights, rotary_emb):
super().__init__()
layer_norm_eps = config.layer_norm_eps
@@ -248,7 +241,10 @@ class FlashNeoXLayer(nn.Module):
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention(
- config, prefix=f"{prefix}.attention", weights=weights
+ config,
+ prefix=f"{prefix}.attention",
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
@@ -328,9 +324,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
prefix=f"{prefix}.embed_in", weights=weights
)
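+        # GPT-NeoX applies rotary embeddings to only a rotary_pct fraction of each head's dimensions.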
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=int(
+ config.rotary_pct * (config.hidden_size // config.num_attention_heads)
+ ),
+ base=config.rotary_emb_base,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
- FlashNeoXLayer(layer_id, config, weights)
+ FlashNeoXLayer(layer_id, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 78aaf0d5..d7fc844b 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -113,6 +113,7 @@ class FlashPhiAttention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -121,13 +122,7 @@ class FlashPhiAttention(torch.nn.Module):
self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.rotary_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@@ -259,11 +254,14 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module):
- def __init__(self, prefix: str, layer_id, config, weights):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
@@ -315,6 +313,16 @@ class FlashPhiModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
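+        # Phi uses partial rotary embeddings: only partial_rotary_factor of each head dimension is rotated.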
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=int(
+ config.partial_rotary_factor
+ * (config.hidden_size // config.num_attention_heads)
+ ),
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
FlashPhiLayer(
@@ -322,6 +330,7 @@ class FlashPhiModel(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index ac31e53b..13e1c916 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -58,6 +58,7 @@ class Qwen2Attention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.max_past = (
@@ -66,13 +67,7 @@ class Qwen2Attention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -199,11 +194,14 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module):
- def __init__(self, prefix, layer_id, config, weights):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
)
self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
@@ -258,6 +256,14 @@ class Qwen2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
self.layers = nn.ModuleList(
[
Qwen2Layer(
@@ -265,6 +271,7 @@ class Qwen2Model(torch.nn.Module):
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
index 8bd00c13..63ee4c97 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -41,7 +41,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config, prefix, weights, layer_idx):
+ def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -54,12 +54,7 @@ class Qwen3Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.attention_dropout = config.attention_dropout
self.softmax_scale = self.head_dim**-0.5
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@@ -179,7 +174,7 @@ class Qwen3Attention(nn.Module):
class Qwen3DecoderLayer(nn.Module):
- def __init__(self, config, prefix, weights, layer_idx: int):
+ def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(
@@ -187,6 +182,7 @@ class Qwen3DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
)
self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
self.input_layernorm = FastRMSNorm.load(
@@ -241,6 +237,15 @@ class Qwen3Model(nn.Module):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
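+        # Fall back to hidden_size // num_attention_heads when the config provides no head_dim.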
+ head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
@@ -249,6 +254,7 @@ class Qwen3Model(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
)
for layer_idx in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
index 5e4bc7fa..da474adc 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
@@ -80,7 +80,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class Qwen3MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config, prefix, weights, layer_idx):
+ def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -108,13 +108,7 @@ class Qwen3MoeAttention(nn.Module):
self.o_proj = FastLinear.load(
config, f"{prefix}.o_proj", weights, bias=config.attention_bias
)
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_dim,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",
@@ -345,7 +339,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module):
- def __init__(self, config, prefix, weights, layer_idx: int):
+ def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
super().__init__()
self.hidden_size = config.hidden_size
@@ -355,6 +349,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
)
else:
self.self_attn = Qwen3MoeAttention(
@@ -362,6 +357,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
)
moe_layer_cls = (
@@ -433,6 +429,15 @@ class Qwen3MoeModel(nn.Module):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
+ head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
@@ -441,6 +446,7 @@ class Qwen3MoeModel(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
)
for layer_idx in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 06616f85..bd8397f1 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -134,6 +134,7 @@ class FlashRWAttention(torch.nn.Module):
config,
prefix: str,
weights,
+ rotary_emb,
):
super().__init__()
self.num_heads = config.n_head
@@ -141,13 +142,8 @@ class FlashRWAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rope_theta = config.rope_theta
+ self.rotary_emb = rotary_emb
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=self.rope_theta,
- device=weights.device,
- )
self.softmax_scale = self.head_size ** (-0.5)
if self.num_heads % weights.process_group.size() != 0:
@@ -243,6 +239,7 @@ class FlashRWLargeAttention(torch.nn.Module):
config,
prefix: str,
weights,
+ rotary_emb,
):
super().__init__()
@@ -255,13 +252,8 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.num_groups = num_groups
self.rope_theta = config.rope_theta
+ self.rotary_emb = rotary_emb
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=self.rope_theta,
- device=weights.device,
- )
self.softmax_scale = self.head_size ** (-0.5)
# self.num_groups = num_heads // (num_heads_kv * 2)
@@ -382,6 +374,7 @@ class FlashRWLayer(nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
@@ -404,6 +397,7 @@ class FlashRWLayer(nn.Module):
config,
prefix=f"{prefix}.self_attention",
weights=weights,
+ rotary_emb=rotary_emb,
)
self.post_attention_layernorm = (
FastLayerNorm.load(
@@ -526,7 +520,7 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module):
- def __init__(self, layer_id, prefix: str, config, weights):
+ def __init__(self, layer_id, prefix: str, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.h.{layer_id}"
@@ -536,6 +530,7 @@ class FlashRWLargeLayer(nn.Module):
config,
prefix=f"{prefix}.self_attention",
weights=weights,
+ rotary_emb=rotary_emb,
)
assert config.parallel_attn, "This version doesn't support non parallel_attn"
@@ -593,11 +588,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
self.word_embeddings = TensorParallelEmbedding(
prefix=f"{prefix}.word_embeddings", weights=weights
)
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.n_head,
+ base=config.rope_theta,
+ device=weights.device,
+ )
if config.new_decoder_architecture:
self.h = nn.ModuleList(
[
- FlashRWLargeLayer(layer_id, prefix, config, weights)
+ FlashRWLargeLayer(layer_id, prefix, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -605,7 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.h = nn.ModuleList(
[
- FlashRWLayer(layer_id, prefix, config, weights)
+ FlashRWLayer(layer_id, prefix, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 1a749595..45baf4db 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -180,6 +180,7 @@ class Starcoder2Attention(torch.nn.Module):
prefix: str,
config,
weights,
+ rotary_emb,
):
super().__init__()
self.max_past = (
@@ -188,13 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
-
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config,
- dim=self.head_size,
- base=config.rope_theta,
- device=weights.device,
- )
+ self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5
@@ -411,11 +406,15 @@ STARCODER2_MLP_CLASSES = {
class Starcoder2Layer(nn.Module):
- def __init__(self, layer_id, config, weights):
+ def __init__(self, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ index=layer_id,
+ rotary_emb=rotary_emb,
)
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
@@ -481,12 +480,19 @@ class Starcoder2Model(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
self.layers = nn.ModuleList(
[
Starcoder2Layer(
layer_id,
config,
weights,
+ rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py
deleted file mode 100644
index 6ce2054e..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Idefics model configuration"""
-import copy
-
-from transformers import PretrainedConfig
-
-IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json",
- "HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json",
-}
-
-
-class IdeficsVisionConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
- Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the Idefics-9B.
- e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- hidden_size (`int`, *optional*, defaults to 768):
- Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`)
- image_size (`int`, *optional*, defaults to 224):
- The size (resolution) of each image.
- intermediate_size (`int`, *optional*, defaults to 5120):
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
- patch_size (`int`, *optional*, defaults to 14):
- The size (resolution) of each patch.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 16):
- Number of attention heads for each attention layer in the Transformer encoder.
- image_num_channels (`int`, *optional*, defaults to `3`):
- Number of image channels.
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
- `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
- layer_norm_eps (`float`, *optional*, defaults to 1e-5):
- The epsilon used by the layer normalization layers.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- initializer_factor (`float`, *optional*, defaults to 1.0):
- A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
- testing).
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- """
-
- model_type = "idefics"
- attribute_map = {
- "hidden_size": "embed_dim",
- }
-
- def __init__(
- self,
- embed_dim=768,
- image_size=224,
- intermediate_size=5120,
- patch_size=14,
- num_hidden_layers=32,
- num_attention_heads=16,
- num_channels=3,
- hidden_act="gelu",
- layer_norm_eps=1e-5,
- attention_dropout=0.0,
- initializer_range=0.02,
- initializer_factor=1.0,
- **kwargs,
- ):
- self.embed_dim = embed_dim
- self.image_size = image_size
- self.intermediate_size = intermediate_size
- self.patch_size = patch_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.num_channels = num_channels
- self.layer_norm_eps = layer_norm_eps
- self.attention_dropout = attention_dropout
- self.initializer_range = initializer_range
- self.initializer_factor = initializer_factor
- self.hidden_act = hidden_act
-
- super().__init__(**kwargs)
-
-
-class IdeficsPerceiverConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
- Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the Idefics-9B.
- e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- use_resampler (`bool`, *optional*, defaults to `False`):
- Whether or not to use the resampler
- resampler_n_latents (`int`, *optional*, defaults to ):
- Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
- resampler_depth (`int`, *optional*, defaults to 6):
- Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
- resampler_n_heads (`int`, *optional*, defaults to 16):
- Number of heads in each Transformer block (for multi-headed self-attention).
- resampler_head_dim (`int`, *optional*, defaults to 96):
- Dimensionality of each head projection in the Transformer block.
- qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
- Whether or not to use qk layer norms in perceiver
- """
-
- model_type = "idefics"
-
- def __init__(
- self,
- use_resampler=False,
- resampler_n_latents=64,
- resampler_depth=6,
- resampler_n_heads=16,
- resampler_head_dim=96,
- qk_layer_norms_perceiver=False,
- **kwargs,
- ):
- self.use_resampler = use_resampler
- self.resampler_n_latents = resampler_n_latents
- self.resampler_depth = resampler_depth
- self.resampler_n_heads = resampler_n_heads
- self.resampler_head_dim = resampler_head_dim
- self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
-
- super().__init__(**kwargs)
-
-
-class IdeficsConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
- Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the Idefics-9B.
- e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- additional_vocab_size (`int`, *optional`, defaults to 0):
-            Additional vocabulary size of the model, typically for the special "<image>" token. Additional vocab tokens
- are always trainable whereas regular vocab tokens can be frozen or not.
- vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`~IdeficsModel`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 11008):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
- Initialization type for the alphas.
- alphas_initializer_range (`float`, *optional*, defaults to 0.0):
- The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
- Attention.
- alpha_type (`str`, *optional*, defaults to `"float"`):
- Whether the gating alphas should be vectors or single floats.
- rms_norm_eps (`float`, *optional*, defaults to 1e-6):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*, defaults to 0)
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 1)
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 2)
- End of stream token id.
- tie_word_embeddings(`bool`, *optional*, defaults to `False`):
- Whether to tie weight embeddings
- cross_layer_interval (`int`, *optional*, default to 1)
- Interval for cross attention (from text to image) layers.
- qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
- freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
- freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):
- Exceptions to freezing text layers when `freeze_text_layers` is `True`
- freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
- freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
- freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):
- Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
- use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
- vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
- perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
- Example:
- ```python
- >>> from transformers import IdeficsModel, IdeficsConfig
- >>> # Initializing a Idefics idefics-9b style configuration
- >>> configuration = IdeficsConfig()
- >>> # Initializing a model from the idefics-9b style configuration
- >>> model = IdeficsModel(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "idefics"
- is_composition = True
-
- def __init__(
- self,
- vocab_size=32000,
- additional_vocab_size=0,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- dropout=0.0,
- hidden_act="silu",
- initializer_range=0.02,
- alpha_initializer="zeros",
- alphas_initializer_range=0.0,
- alpha_type="float",
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- tie_word_embeddings=False,
- cross_layer_interval=1,
- qk_layer_norms=False,
- freeze_text_layers=True,
- freeze_text_module_exceptions=[],
- freeze_lm_head=False,
- freeze_vision_layers=True,
- freeze_vision_module_exceptions=[],
- use_resampler=False,
- vision_config=None,
- perceiver_config=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.additional_vocab_size = additional_vocab_size
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.dropout = dropout
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.alpha_initializer = alpha_initializer
- self.alphas_initializer_range = alphas_initializer_range
- self.alpha_type = alpha_type
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
-
- self.cross_layer_interval = cross_layer_interval
- self.qk_layer_norms = qk_layer_norms
- self.freeze_vision_layers = freeze_vision_layers
-
- self.freeze_text_layers = freeze_text_layers
- self.freeze_text_module_exceptions = freeze_text_module_exceptions
- self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
- self.freeze_lm_head = freeze_lm_head
-
- self.use_resampler = use_resampler
-
- if perceiver_config is None:
- self.perceiver_config = IdeficsPerceiverConfig()
- elif isinstance(perceiver_config, dict):
- self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config)
- elif isinstance(perceiver_config, IdeficsPerceiverConfig):
- self.perceiver_config = perceiver_config
-
- if vision_config is None:
- self.vision_config = IdeficsVisionConfig()
- elif isinstance(vision_config, dict):
- self.vision_config = IdeficsVisionConfig(**vision_config)
- elif isinstance(vision_config, IdeficsVisionConfig):
- self.vision_config = vision_config
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
- # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
- # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
- # updates the config object with `kwargs` from from_pretrained, so during the instantiation
- # of this object many attributes have default values and haven't yet been overridden.
- # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
-
- def to_dict(self):
- """
-        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
-        Returns:
-            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- output = copy.deepcopy(self.__dict__)
-
- output["vision_config"] = self.vision_config.to_dict()
- output["perceiver_config"] = self.perceiver_config.to_dict()
- output["model_type"] = self.__class__.model_type
-
- return output
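The override above matters because the nested `vision_config` and `perceiver_config` are config objects rather than plain values; converting them explicitly keeps the whole configuration JSON-serializable. A minimal round-trip sketch, assuming the module is still importable from its pre-removal path:

```python
# Hedged sketch of the removed IdeficsConfig.to_dict behaviour (pre-removal import path).
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig

config = IdeficsConfig(hidden_size=4096, use_resampler=True)
as_dict = config.to_dict()

# Nested configs are serialized to plain dicts so the result is JSON-friendly.
assert isinstance(as_dict["vision_config"], dict)
assert isinstance(as_dict["perceiver_config"], dict)
assert as_dict["model_type"] == "idefics"

# Rebuilding goes through the `isinstance(..., dict)` branches of __init__.
rebuilt = IdeficsConfig.from_dict(as_dict)
```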
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
deleted file mode 100644
index afb8e1f9..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
+++ /dev/null
@@ -1,297 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Image processor class for Idefics."""
-
-from typing import Callable, Dict, List, Optional, Union, Iterable
-import numpy as np
-
-from PIL import Image
-
-import transformers
-from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
-from transformers.image_transforms import (
- resize,
- to_channel_dimension_format,
- rescale,
- normalize,
-)
-from transformers.image_utils import (
- ChannelDimension,
- ImageInput,
- PILImageResampling,
- make_list_of_images,
- to_numpy_array,
- valid_images,
-)
-from io import BytesIO
-import base64
-import requests
-from transformers import TensorType, is_torch_available
-
-
-IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
-IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
-
-
-def convert_to_rgb(image):
-    # `image.convert("RGB")` only works correctly for .jpg images, since it creates a wrong background
-    # for transparent images. The call to `alpha_composite` handles this case.
- if image.mode == "RGB":
- return image
-
- image_rgba = image.convert("RGBA")
- background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
- alpha_composite = Image.alpha_composite(background, image_rgba)
- alpha_composite = alpha_composite.convert("RGB")
- return alpha_composite
-
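As a quick illustration of the compositing above (a hedged sketch, assuming Pillow is installed), a fully transparent pixel ends up white instead of keeping its hidden RGB channels:

```python
# Sketch: alpha-compositing a transparent RGBA image onto a white background,
# using the convert_to_rgb helper defined above.
from PIL import Image

rgba = Image.new("RGBA", (2, 2), (255, 0, 0, 0))  # fully transparent "red"
rgb = convert_to_rgb(rgba)

# A naive rgba.convert("RGB") would keep (255, 0, 0); compositing yields white.
assert rgb.mode == "RGB"
assert rgb.getpixel((0, 0)) == (255, 255, 255)
```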
-
-class IdeficsImageProcessor(BaseImageProcessor):
- r"""
-    Constructs an Idefics image processor.
- Args:
-        image_size (`int`, *optional*, defaults to `224`):
-            Size to resize the input images to.
-        image_num_channels (`int`, *optional*, defaults to `3`):
-            Number of image channels.
-        image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
-            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
-            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
-        image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
-            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
-            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- image_size: int = 224,
- image_mean: Optional[Union[float, List[float]]] = None,
- image_std: Optional[Union[float, List[float]]] = None,
- image_num_channels: Optional[int] = 3,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.image_size = image_size
- self.image_num_channels = image_num_channels
- self.image_mean = image_mean
- self.image_std = image_std
-
- def preprocess(
- self,
- images: ImageInput,
- image_num_channels: Optional[int] = 3,
- image_size: Optional[Dict[str, int]] = None,
- image_mean: Optional[Union[float, List[float]]] = None,
- image_std: Optional[Union[float, List[float]]] = None,
- transform: Callable = None,
- **kwargs,
- ) -> TensorType.PYTORCH:
- """
- Preprocess a batch of images.
- Args:
- images (`ImageInput`):
- A list of images to preprocess.
-            image_size (`int`, *optional*, defaults to `self.image_size`):
-                Size to resize the input images to.
-            image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
-                Number of image channels.
-            image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
-                Mean to use if normalizing the image. This is a float or list of floats the length of the number of
-                channels in the image.
-            image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
-                Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
-                number of channels in the image.
-            transform (`Callable`, *optional*, defaults to `None`):
-                A custom transform function that accepts a single image can be passed for training. For example,
-                `torchvision.transforms.Compose` can be used to compose multiple transforms. If `None`, inference mode
-                is assumed and a preset of inference-specific transforms is applied to the images.
- Returns:
- a PyTorch tensor of the processed images
- """
- image_size = image_size if image_size is not None else self.image_size
- image_num_channels = (
- image_num_channels
- if image_num_channels is not None
- else self.image_num_channels
- )
- image_mean = image_mean if image_mean is not None else self.image_mean
- image_std = image_std if image_std is not None else self.image_std
- size = (image_size, image_size)
-
- if len(images) == 0:
- return []
-
- images = make_list_of_images(images)
-
- if not valid_images(images):
- raise ValueError(
- "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
- "torch.Tensor, tf.Tensor or jax.ndarray."
- )
-
- # For training a user needs to pass their own set of transforms as a Callable.
- # For reference this is what was used in the original IDEFICS training:
- # transform = transforms.Compose([
- # convert_to_rgb,
- # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
- # transforms.ToTensor(),
- # transforms.Normalize(mean=image_mean, std=image_std),
- # ])
- if transform is not None:
- if not is_torch_available():
- raise ImportError("To pass in `transform` torch must be installed")
- import torch
-
- images = [transform(x) for x in images]
- return torch.stack(images)
-
- # for inference we do the exact transforms that were used to train IDEFICS
- images = [convert_to_rgb(x) for x in images]
- # further transforms expect numpy arrays
- images = [to_numpy_array(x) for x in images]
- images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
- images = [self.rescale(image=image, scale=1 / 255) for image in images]
- images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
- images = [
- to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
- ]
- # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
- images = BatchFeature(
- data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
- )["pixel_values"]
-
- return images
-
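A short usage sketch of the inference path above, assuming a single in-memory PIL image; the normalization constants defined at the top of this file are passed explicitly because the constructor defaults leave them as `None`:

```python
# Sketch: running the preset inference transforms on one image.
from PIL import Image

processor = IdeficsImageProcessor(
    image_size=224,
    image_mean=IDEFICS_STANDARD_MEAN,
    image_std=IDEFICS_STANDARD_STD,
)
image = Image.new("RGB", (640, 480), (128, 128, 128))

pixel_values = processor.preprocess([image])
print(pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224]), channels-first
```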
- def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
- """
- Convert a single or a list of urls into the corresponding `PIL.Image` objects.
- If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
- returned.
- """
- headers = {
- "User-Agent": (
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
- " Safari/537.36"
- )
- }
- if isinstance(image_url_or_urls, list):
- return [self.fetch_images(x) for x in image_url_or_urls]
- elif isinstance(image_url_or_urls, str):
- image = image_url_or_urls
-
- if image.startswith("http://") or image.startswith("https://"):
- response = requests.get(
- image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
- )
- response.raise_for_status()
- content = response.content
- elif image.startswith("data:"):
- # https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image
- # data:image/png;base64,xxx
- image = image.split(",")[-1]
- content = base64.b64decode(image)
- else:
- raise ValueError(f"Unrecognized image {image}")
-
- try:
- image = Image.open(BytesIO(content))
- # image.verify()
- except Exception:
- raise ValueError(f"Could not load image from url {image_url_or_urls}")
- return image
- else:
- raise ValueError(
- f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
- )
-
- def rescale(
- self,
- image: np.ndarray,
- scale: float,
- data_format: Optional[Union[str, ChannelDimension]] = None,
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
- **kwargs,
- ) -> np.ndarray:
- """
- Rescale an image by a scale factor. image = image * scale.
-
- Args:
- image (`np.ndarray`):
- Image to rescale.
- scale (`float`):
- The scaling factor to rescale pixel values by.
- data_format (`str` or `ChannelDimension`, *optional*):
- The channel dimension format for the output image. If unset, the channel dimension format of the input
- image is used. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- input_data_format (`ChannelDimension` or `str`, *optional*):
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
- from the input image. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-
- Returns:
- `np.ndarray`: The rescaled image.
- """
- # return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
- # requires 4.32
- return rescale(image, scale=scale, data_format=data_format, **kwargs)
-
- def normalize(
- self,
- image: np.ndarray,
- mean: Union[float, Iterable[float]],
- std: Union[float, Iterable[float]],
- data_format: Optional[Union[str, ChannelDimension]] = None,
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
- **kwargs,
- ) -> np.ndarray:
- """
- Normalize an image. image = (image - image_mean) / image_std.
-
- Args:
- image (`np.ndarray`):
- Image to normalize.
- mean (`float` or `Iterable[float]`):
- Image mean to use for normalization.
- std (`float` or `Iterable[float]`):
- Image standard deviation to use for normalization.
- data_format (`str` or `ChannelDimension`, *optional*):
- The channel dimension format for the output image. If unset, the channel dimension format of the input
- image is used. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- input_data_format (`ChannelDimension` or `str`, *optional*):
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
- from the input image. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-
- Returns:
- `np.ndarray`: The normalized image.
- """
- # TODO 4.32
- return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
-
-
-transformers.IdeficsImageProcessor = IdeficsImageProcessor
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
deleted file mode 100644
index 910e9bcd..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ /dev/null
@@ -1,1474 +0,0 @@
-# coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Idefics model."""
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from transformers import PreTrainedModel
-from transformers.activations import ACT2FN
-from dataclasses import dataclass
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_vision import (
- IdeficsVisionTransformer,
-)
-from text_generation_server.models.custom_modeling.idefics_perceiver import (
- IdeficsPerceiverResampler,
-)
-from text_generation_server.layers import (
- TensorParallelColumnLinear,
- TensorParallelEmbedding,
- TensorParallelRowLinear,
- SpeculativeHead,
- FastLinear,
-)
-from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from loguru import logger
-
-dropout_layer_norm = None
-
-
-@dataclass
-class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
- image_hidden_states: Optional[torch.FloatTensor] = None
-
-
-@dataclass
-class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
- image_hidden_states: Optional[torch.FloatTensor] = None
-
-
-# logger = logging.get_logger(__name__)
-
-# _CONFIG_FOR_DOC = "IdeficsConfig"
-
-# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [
-# "HuggingFaceM4/idefics-9b",
-# "HuggingFaceM4/idefics-80b",
-# # See all Idefics models at https://huggingface.co/models?filter=idefics
-# ]
-
-
-def expand_inputs_for_generation(
- input_ids,
- expand_size=1,
- is_encoder_decoder=False,
- attention_mask=None,
- encoder_outputs=None,
- **model_kwargs,
-):
- expanded_return_idx = (
- torch.arange(input_ids.shape[0])
- .view(-1, 1)
- .repeat(1, expand_size)
- .view(-1)
- .to(input_ids.device)
- )
- input_ids = input_ids.index_select(0, expanded_return_idx)
-
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = token_type_ids.index_select(
- 0, expanded_return_idx
- )
-
- if attention_mask is not None:
- model_kwargs["attention_mask"] = attention_mask.index_select(
- 0, expanded_return_idx
- )
- model_kwargs["image_attention_mask"] = model_kwargs[
- "image_attention_mask"
- ].index_select(0, expanded_return_idx)
- model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
- 0, expanded_return_idx
- )
-
- if is_encoder_decoder:
- if encoder_outputs is None:
- raise ValueError(
- "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
- )
- encoder_outputs["last_hidden_state"] = (
- encoder_outputs.last_hidden_state.index_select(
- 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
- )
- )
- model_kwargs["encoder_outputs"] = encoder_outputs
- return input_ids, model_kwargs
-
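To make the batch bookkeeping above concrete, here is a small sketch of `expand_inputs_for_generation` with `expand_size=2`; the tensor values are made up for illustration:

```python
# Sketch: expanding a batch of 2 prompts to 4 rows (expand_size=2).
import torch

input_ids = torch.tensor([[1, 2], [3, 4]])
attention_mask = torch.ones_like(input_ids)
image_attention_mask = torch.ones(2, 2, 1, dtype=torch.long)
pixel_values = torch.zeros(2, 1, 3, 224, 224)

expanded_ids, kwargs = expand_inputs_for_generation(
    input_ids,
    expand_size=2,
    attention_mask=attention_mask,
    image_attention_mask=image_attention_mask,
    pixel_values=pixel_values,
)
print(expanded_ids)                  # rows repeated: [[1, 2], [1, 2], [3, 4], [3, 4]]
print(kwargs["pixel_values"].shape)  # torch.Size([4, 1, 3, 224, 224])
```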
-
-def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
- # must have this key set to at least None
- model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
-
- # update past
- if "past_key_values" in outputs:
- model_kwargs["past"] = outputs.past_key_values
- elif "mems" in outputs:
- model_kwargs["past"] = outputs.mems
- elif "past_buckets_states" in outputs:
- model_kwargs["past"] = outputs.past_buckets_states
- else:
- model_kwargs["past"] = None
-
- # update token_type_ids with last value
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = torch.cat(
- [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
- )
-
- # update attention masks
- if not is_encoder_decoder:
- if "attention_mask" in model_kwargs:
- attention_mask = model_kwargs["attention_mask"]
- model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
- dim=-1,
- )
- if "image_attention_mask" in model_kwargs:
- image_attention_mask = model_kwargs["image_attention_mask"]
- last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
- model_kwargs["image_attention_mask"] = last_mask
-
- return model_kwargs
-
-
-def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
- token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
- if past:
- input_ids = input_ids[:, -1].unsqueeze(-1)
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
-
- attention_mask = kwargs.get("attention_mask", None)
- position_ids = kwargs.get("position_ids", None)
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past:
- position_ids = position_ids[:, -1].unsqueeze(-1)
-
- pixel_values = kwargs.get("pixel_values", None)
- image_attention_mask = kwargs.get("image_attention_mask", None)
- # if pixel_values is None or image_attention_mask is None:
- # raise ValueError("pixel values and image attention mask cannot be None")
-
- return {
- "input_ids": input_ids,
- "past_key_values": past,
- "use_cache": kwargs.get("use_cache"),
- "position_ids": position_ids,
- "attention_mask": attention_mask,
- "token_type_ids": token_type_ids,
- "pixel_values": pixel_values,
- "image_attention_mask": image_attention_mask,
- }
-
-
-def freeze_model(model, module_exceptions=[]):
- mapping = {
- "LayerNorm": nn.LayerNorm,
- "Linear": nn.Linear,
- "Embedding": nn.Embedding,
- }
- module_exceptions_mapped = [mapping[m] for m in module_exceptions]
- for module in model.modules():
- if module_exceptions and any(
- [isinstance(module, t) for t in module_exceptions_mapped]
- ):
- module.requires_grad_(
- True
-            )  # Explicitly setting it to True to avoid any mistakes
- else:
- module.requires_grad_(False)
- return model
-
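A tiny sketch of `freeze_model` on a toy module, assuming only `Embedding` layers are excepted from freezing:

```python
# Sketch: freeze everything except Embedding modules.
import torch.nn as nn

toy = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4), nn.LayerNorm(4))
freeze_model(toy, module_exceptions=["Embedding"])

assert toy[0].weight.requires_grad is True   # excepted, stays trainable
assert toy[1].weight.requires_grad is False  # frozen
assert toy[2].weight.requires_grad is False  # frozen
```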
-
-class IdeficsDecoupledPartialTPEmbedding(nn.Module):
- def __init__(
- self,
- config,
- weights,
- ):
- super().__init__()
- self.num_embeddings = config.vocab_size
- self.weight = TensorParallelEmbedding(
- prefix="model.embed_tokens", weights=weights
- )
- self.additional_weight = nn.Parameter(
- weights.get_tensor("model.embed_tokens.additional_embedding.weight")
- )
-
- def forward(self, input_ids):
- # Clone so that we don't modify the original input_ids later on
- input_ids = input_ids.clone()
- additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
- input_ids_additional_vocab = input_ids[additional_vocab_indices]
- additional_embeddings = torch.nn.functional.embedding(
- input_ids_additional_vocab - self.num_embeddings, self.additional_weight
- )
-
- # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
- input_ids[additional_vocab_indices] = 0
- full_vector = self.weight(input_ids)
-
- # overwrite the records with high indices
- full_vector[additional_vocab_indices] = additional_embeddings
-
- return full_vector
-
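The index routing in `forward` above is easier to follow with plain `nn.Embedding` stand-ins for the sharded tables; this is a hedged sketch of the arithmetic, not the tensor-parallel implementation:

```python
# Sketch: ids below vocab_size hit the main table, ids at or above it hit the additional table.
import torch
import torch.nn as nn

vocab_size, additional, dim = 8, 2, 4
main_table = nn.Embedding(vocab_size, dim)
additional_table = nn.Embedding(additional, dim)

input_ids = torch.tensor([[1, 7, 8, 9]])            # 8 and 9 are "additional" tokens
extra_idx = torch.where(input_ids >= vocab_size)
extra_vecs = additional_table(input_ids[extra_idx] - vocab_size)

safe_ids = input_ids.clone()
safe_ids[extra_idx] = 0                             # dummy lookup, overwritten below
out = main_table(safe_ids)
out[extra_idx] = extra_vecs
print(out.shape)                                    # torch.Size([1, 4, 4])
```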
-
-class IdeficsDecoupledTensorParallelLinear(nn.Module):
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
- """
-    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
- then it will create `out_additional_features * in_features` additional parameters that are always trained. If
- `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
- """
-
- def __init__(
- self,
- config,
- weights,
- ) -> None:
- super().__init__()
- self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
- self.additional_fc = FastLinear.load(
- config=config,
- prefix="lm_head.additional_fc",
- weights=weights,
- bias=False,
- )
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- output, speculative_logits = self.fc(input)
- additional_features = self.additional_fc(input)
- output = torch.cat((output, additional_features), -1)
-
- return output, speculative_logits
-
- def extra_repr(self) -> str:
- """Overwriting `nn.Linear.extra_repr` to include new parameters."""
- return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
- self.in_features,
- self.out_features,
- self.out_additional_features,
- self.bias is not None,
- self.partially_freeze,
- )
-
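The decoupled head applies the same idea on the output side; a shape-only sketch with plain `nn.Linear` stand-ins for the sharded `SpeculativeHead`/`FastLinear` weights:

```python
# Sketch: main lm_head logits concatenated with the always-trainable extra logits.
import torch
import torch.nn as nn

hidden, vocab_size, additional_vocab = 16, 32, 2
fc = nn.Linear(hidden, vocab_size, bias=False)
additional_fc = nn.Linear(hidden, additional_vocab, bias=False)

x = torch.randn(1, 5, hidden)
logits = torch.cat((fc(x), additional_fc(x)), dim=-1)
print(logits.shape)  # torch.Size([1, 5, 34]) == vocab_size + additional_vocab
```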
-
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
- input_ids_shape: torch.Size,
- dtype: torch.dtype,
- device: torch.device,
- past_key_values_length: int = 0,
-):
- """
-    Make causal mask used for uni-directional (auto-regressive) self-attention.
- """
- bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
- mask_cond = torch.arange(mask.size(-1), device=device)
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
-
- if past_key_values_length > 0:
- mask = torch.cat(
- [
- torch.zeros(
- tgt_len, past_key_values_length, dtype=dtype, device=device
- ),
- mask,
- ],
- dim=-1,
- )
- return mask[None, None, :, :].expand(
- bsz, 1, tgt_len, tgt_len + past_key_values_length
- )
-
-
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
- )
-
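A quick numeric sketch of the two mask helpers above; blocked positions are filled with the dtype minimum and allowed positions with zero:

```python
# Sketch: causal mask for a 3-token prompt and expansion of a padding mask.
import torch

causal = _make_causal_mask((1, 3), torch.float32, device=torch.device("cpu"))
print(causal.shape)  # torch.Size([1, 1, 3, 3]); upper triangle holds the dtype minimum

pad = _expand_mask(torch.tensor([[1.0, 1.0, 0.0]]), torch.float32)
print(pad.shape)     # torch.Size([1, 1, 3, 3]); the last source column is masked out
```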
-
-class IdeficsRMSNorm(nn.Module):
- def __init__(self, prefix, weights, eps=1e-6):
- """
-        IdeficsRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
-
- weight = weights.get_tensor(f"{prefix}.weight")
- self.weight = nn.Parameter(weight)
- self.variance_epsilon = eps
-
- def forward(self, hidden_states, residual=None):
- from vllm_hpu_extension.kernels import rms_norm
-
- orig_shape = hidden_states.shape
- if residual is not None:
- residual += hidden_states.view(residual.shape)
- else:
- residual = hidden_states
- # Note: HPUFusedRMSNorm requires 3D tensors as inputs
- if len(orig_shape) == 2:
- residual = residual.unsqueeze(0)
- x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
- return x.view(orig_shape), residual.view(orig_shape)
-
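For readers without the HPU extension, the fused kernel above computes standard RMS normalization; a plain-PyTorch reference of the same math (ignoring the residual bookkeeping), offered as a sketch rather than the kernel's exact implementation:

```python
# Sketch: reference RMSNorm matching the math of the fused HPU path.
import torch

def rms_norm_reference(hidden_states, weight, eps=1e-6):
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * hidden_states * torch.rsqrt(variance + eps)

x = torch.randn(2, 5, 8)
y = rms_norm_reference(x, torch.ones(8))
print(y.shape)  # torch.Size([2, 5, 8])
```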
-
-# this was adapted from LlamaMLP
-class IdeficsMLP(nn.Module):
- def __init__(
- self,
- config,
- prefix,
- weights,
- ):
- super().__init__()
- self.gate_up_proj = TensorParallelColumnLinear.load_multi(
- config,
- prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
- weights=weights,
- dim=0,
- bias=False,
- )
- self.down_proj = TensorParallelRowLinear.load(
- config,
- prefix=f"{prefix}.down_proj",
- weights=weights,
- bias=False,
- )
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, hidden_states):
- gate_up_states = self.gate_up_proj(hidden_states)
- shape = gate_up_states.shape
- gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
- return self.down_proj(
- self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
- )
-
-
-# this was adapted from LlamaAttention
-class IdeficsAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(
- self,
- config,
- prefix,
- weights,
- qk_layer_norms: bool = False,
- is_cross_attention: bool = False,
- ):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.dropout = config.dropout
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
-
- self.is_cross_attention = is_cross_attention
-
- # if not hasattr(nn.functional, "scaled_dot_product_attention"):
- # raise ValueError("this model requires pytorch 2.0 or higher")
-
- if self.num_heads % weights.process_group.size() != 0:
- raise ValueError(
- f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
- f"and `num_shards`: {weights.process_group.size()}"
- )
- self.num_heads //= weights.process_group.size()
-
- if self.is_cross_attention:
- # kv_input_dim = (
- # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
- # )
- self.q_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
- )
- self.k_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
- )
- self.v_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
- )
- else:
- self.qkv = TensorParallelColumnLinear.load_multi(
- config,
- prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
- dim=0,
- weights=weights,
- bias=False,
- )
- self.o_proj = TensorParallelRowLinear.load(
- config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
- )
- self.rotary_emb = PositionRotaryEmbedding.static(
- config=config, dim=self.head_dim, base=10000.0, device=weights.device
- )
- self.qk_layer_norms = qk_layer_norms
- if self.qk_layer_norms:
- self.q_layer_norm = IdeficsRMSNorm(
- prefix=f"{prefix}.q_layer_norm",
- weights=weights,
- eps=config.rms_norm_eps,
- )
- self.k_layer_norm = IdeficsRMSNorm(
- prefix=f"{prefix}.q_layer_norm",
- weights=weights,
- eps=config.rms_norm_eps,
- )
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return (
- tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- .contiguous()
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- is_cross_attention = self.is_cross_attention or key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
-
- if is_cross_attention:
- query_states = self.q_proj(hidden_states).view(
- bsz, q_len, self.num_heads, self.head_dim
- ) # .transpose(1, 2)
- query_states = query_states.transpose(1, 2)
-            # Note that, in this case, `kv_len` == `kv_seq_len`
-            _, kv_len, _ = key_value_states.size()
- key_states = (
- self.k_proj(key_value_states)
- .view(bsz, kv_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- )
- value_states = (
- self.v_proj(key_value_states)
- .view(bsz, kv_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- )
- else:
- qkv = self.qkv(hidden_states)
- query_states, key_states, value_states = qkv.split(
- self.num_heads * self.head_dim, dim=2
- )
-
- query_states = query_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ) # .transpose(1, 2)
- key_states = key_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ) # . transpose(1, 2)
- value_states = value_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ) # .transpose(1, 2)
- kv_seq_len = q_len
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- max_s = max(kv_seq_len, q_len)
- cos, sin = self.rotary_emb.get_cos_sin(
- position_ids.view(-1), max_s, hidden_states.dtype
- )
-
- query_shape = query_states.shape
- key_shape = key_states.shape
- self.rotary_emb(
- query_states.view(-1, *query_shape[2:]),
- key_states.reshape(-1, *key_shape[2:]),
- cos,
- sin,
- )
-
- query_states = query_states.view(query_shape)
- key_states = key_states.view(key_shape)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- if self.qk_layer_norms:
- query_states = self.q_layer_norm(query_states)
- key_states = self.k_layer_norm(key_states)
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
-
- attn_output = nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout,
- )
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, -1)
-
- attn_output = self.o_proj(attn_output)
-
- attn_weights = None
- if output_attentions:
- logger.warning_once(
- "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
- )
-
- return attn_output, attn_weights, past_key_value
-
-
-# this was adapted from LlamaDecoderLayer
-class IdeficsDecoderLayer(nn.Module):
- def __init__(self, layer_id: int, config: IdeficsConfig, weights):
- super().__init__()
- self.process_group = weights.process_group
- self.hidden_size = config.hidden_size
- prefix = f"model.layers.{layer_id}"
- self.self_attn = IdeficsAttention(
- config=config,
- prefix=f"{prefix}.self_attn",
- weights=weights,
- qk_layer_norms=False,
- is_cross_attention=False,
- )
- self.mlp = IdeficsMLP(
- config=config,
- prefix=f"{prefix}.mlp",
- weights=weights,
- )
- self.input_layernorm = IdeficsRMSNorm(
- prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
- )
- self.post_attention_layernorm = IdeficsRMSNorm(
- prefix=f"{prefix}.post_attention_layernorm",
- weights=weights,
- eps=config.rms_norm_eps,
- )
- self.dropout = config.dropout
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-class IdeficsGatedCrossAttentionLayer(nn.Module):
- def __init__(self, layer_id, config: IdeficsConfig, weights):
- super().__init__()
- self.process_group = weights.process_group
- self.hidden_size = config.hidden_size
- prefix = f"model.gated_cross_attn_layers.{layer_id}"
- self.cross_attn = IdeficsAttention(
- config=config,
- prefix=f"{prefix}.cross_attn",
- weights=weights,
- qk_layer_norms=True,
- is_cross_attention=True,
- )
- self.mlp = IdeficsMLP(
- config=config,
- prefix=f"{prefix}.mlp",
- weights=weights,
- )
- self.input_layernorm = IdeficsRMSNorm(
- prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
- )
- self.post_attention_layernorm = IdeficsRMSNorm(
- prefix=f"{prefix}.post_attention_layernorm",
- weights=weights,
- eps=config.rms_norm_eps,
- )
- self.config = config.dropout
-
- self.act_cross_attn = nn.Tanh()
- self.act_dense = nn.Tanh()
-
- self.alpha_cross_attn = nn.Parameter(
- weights.get_tensor(f"{prefix}.alpha_cross_attn")
- )
- self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))
-
- if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
- raise ValueError("Alpha parameters not initialized correctly!")
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_hidden_states: Optional[torch.Tensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- no_images: Optional[bool] = False,
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored
- """
- if image_hidden_states is None:
- raise ValueError(
- "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
- " conditioned on."
- )
-
- if past_key_value is not None:
- raise NotImplementedError(
- "Past key value states are not implemented for Idefics cross attention module."
- )
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
-        # Cross Attention
- hidden_states, self_attn_weights, present_key_value = self.cross_attn(
- hidden_states=hidden_states,
- key_value_states=image_hidden_states,
- attention_mask=image_attention_mask,
- output_attentions=output_attentions,
- )
- # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
- # when there are no images the model is used in pure language mode
- gate = 0 if no_images else 1
- hidden_states = (
- residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
- )
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
- hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
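The residual update above reduces to a tanh-gated blend; a scalar sketch with made-up values, where `no_images=True` zeroes out the image contribution entirely:

```python
# Sketch: tanh-gated residual update used by the gated cross-attention block.
import torch

residual = torch.randn(4)
cross_attn_out = torch.randn(4)
alpha_cross_attn = torch.tensor(0.5)  # learned; the "zeros" init starts the gate closed

gate = 0  # no_images=True: pure language-modelling mode
hidden = residual + gate * torch.tanh(alpha_cross_attn) * cross_attn_out
assert torch.equal(hidden, residual)  # image branch contributes nothing when gated off
```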
-
-LLAMA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`IdeficsConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-# @add_start_docstrings(
-# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-# LLAMA_START_DOCSTRING,
-# )
-class IdeficsPreTrainedModel(PreTrainedModel):
- config_class = IdeficsConfig
- # base_model_prefix = "model"
- # supports_gradient_checkpointing = True
- # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
-
- # def _init_weights(self, module):
- # # important: this ported version of Idefics isn't meant for training from scratch - only
- # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
- # # base should be used for training from scratch and it contains the correct code.
- # std = self.config.initializer_range
- # if isinstance(module, nn.Linear):
- # module.weight.data.normal_(mean=0.0, std=std)
- # if module.bias is not None:
- # module.bias.data.zero_()
- # elif isinstance(module, nn.Embedding):
- # module.weight.data.normal_(mean=0.0, std=std)
- # if module.padding_idx is not None:
- # module.weight.data[module.padding_idx].zero_()
-
- # def _set_gradient_checkpointing(self, module, value=False):
- # if isinstance(module, IdeficsModel):
- # module.gradient_checkpointing = value
-
-
-# LLAMA_INPUTS_DOCSTRING = r"""
-# Args:
-# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-# it.
-
-# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-# [`PreTrainedTokenizer.__call__`] for details.
-
-# [What are input IDs?](../glossary#input-ids)
-# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-# - 1 for tokens that are **not masked**,
-# - 0 for tokens that are **masked**.
-
-# [What are attention masks?](../glossary#attention-mask)
-
-# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-# [`PreTrainedTokenizer.__call__`] for details.
-
-# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
-# `past_key_values`).
-
-# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-# information on the default strategy.
-
-# - 1 indicates the head is **not masked**,
-# - 0 indicates the head is **masked**.
-# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
-# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-# `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-# model's internal embedding lookup matrix.
-# use_cache (`bool`, *optional*):
-# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-# `past_key_values`).
-# output_attentions (`bool`, *optional*):
-# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-# tensors for more detail.
-# output_hidden_states (`bool`, *optional*):
-# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-# more detail.
-# return_dict (`bool`, *optional*):
-# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-# """
-
-
-# @add_start_docstrings(
-# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-# LLAMA_START_DOCSTRING,
-# )
-class IdeficsModel(IdeficsPreTrainedModel):
- # """
- # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]
-
- # Args:
- # config: IdeficsConfig
- # """
-
- def __init__(self, config: IdeficsConfig, weights):
- super().__init__(config)
- self.config = config
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(
- config=config,
- weights=weights,
- )
-
- self.image_size = config.vision_config.image_size
- self.vision_config = config.vision_config
- self.vision_model = IdeficsVisionTransformer(
- prefix="model.vision_model",
- config=config.vision_config,
- weights=weights,
- )
-
- # Perceiver Resampler
- if config.use_resampler:
- perceiver_config = config.perceiver_config
- self.perceiver_resampler = IdeficsPerceiverResampler(
- prefix="model.perceiver_resampler",
- config=config,
- embed_dim=config.vision_config.embed_dim,
- depth=perceiver_config.resampler_depth,
- n_heads=perceiver_config.resampler_n_heads,
- head_dim=perceiver_config.resampler_head_dim,
- n_latents=perceiver_config.resampler_n_latents,
- weights=weights,
- )
-
- self.layers = nn.ModuleList(
- [
- IdeficsDecoderLayer(layer_id, config, weights)
- for layer_id in range(config.num_hidden_layers)
- ]
- )
-
- self.cross_layer_interval = config.cross_layer_interval
- num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
- self.gated_cross_attn_layers = nn.ModuleList(
- [
- IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
- for layer_id in range(num_cross_layers)
- ]
- )
- # self.gradient_checkpointing = False
-
- self.norm = IdeficsRMSNorm(
- prefix="model.norm", weights=weights, eps=config.rms_norm_eps
- )
-
- # self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- # self.post_init()
-
- # self.freeze_relevant_params(config)
-
- # def freeze_relevant_params(self, config=None):
- # if config is None:
- # config = self.config
-
- # if config.freeze_text_layers:
- # self.freeze_text_layers(config.freeze_text_module_exceptions)
-
- # if config.freeze_vision_layers:
- # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
-
- # def freeze_text_layers(self, module_exceptions=[]):
- # for module in [self.layers, self.norm]:
- # freeze_model(module, module_exceptions=module_exceptions)
-
- # def freeze_vision_layers(self, module_exceptions=[]):
- # freeze_model(self.vision_model, module_exceptions=module_exceptions)
-
- # def get_input_embeddings(self):
- # return self.embed_tokens
-
- # def set_input_embeddings(self, value):
- # self.embed_tokens = value
-
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
- def _prepare_decoder_attention_mask(
- self, attention_mask, input_shape, inputs_embeds, past_key_values_length
- ):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- inputs_embeds.dtype,
- device=inputs_embeds.device,
- past_key_values_length=past_key_values_length,
- )
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- ).to(inputs_embeds.device)
- combined_attention_mask = (
- expanded_attn_mask
- if combined_attention_mask is None
- else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
-
- # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_hidden_states: Optional[torch.FloatTensor] = None,
- image_embeddings: Optional[torch.FloatTensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastImage]:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
-
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
- )
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError(
- "You have to specify either decoder_input_ids or decoder_inputs_embeds"
- )
-
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
- if past_key_values is not None:
- past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- elif position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device,
- )
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- no_images = False
-
- if image_hidden_states is None:
- if pixel_values is None and image_embeddings is None:
- raise ValueError(
- "Either pixel_values and image_embeddings have to be not-None."
- )
-
- elif pixel_values is not None and image_embeddings is not None:
- raise ValueError(
- "You cannot specify both pixel_values and image_embeddings at the same time"
- )
-
- elif pixel_values is not None:
- no_images = len(torch.nonzero(pixel_values)) == 0
- pixel_values = pixel_values.to(
- dtype=self.dtype, device=device
- ) # fp16 compatibility
- batch_size, num_images = pixel_values.shape[:2]
- pixel_values = pixel_values.contiguous().view(
- batch_size * num_images, *pixel_values.shape[2:]
- )
-
- # Get sequence from the vision encoder
- image_hidden_states = self.vision_model(
- pixel_values=pixel_values
- ).last_hidden_state
-
- elif image_embeddings is not None:
- (
- batch_size,
- num_images,
- image_seq_len,
- image_hidden_size,
- ) = image_embeddings.size()
- image_hidden_states = image_embeddings.to(
- dtype=self.dtype, device=input_ids.device
- )
- image_hidden_states = image_hidden_states.view(
- batch_size * num_images, image_seq_len, image_hidden_size
- )
-
- if self.config.use_resampler:
- image_hidden_states = self.perceiver_resampler(image_hidden_states)
- image_seq_len, image_hidden_size = image_hidden_states.size(
- 1
- ), image_hidden_states.size(2)
- image_hidden_states = image_hidden_states.view(
- batch_size, num_images * image_seq_len, image_hidden_size
- )
- else:
- no_images = False
- num_images = pixel_values.shape[1]
- image_seq_len = image_hidden_states.shape[1] // num_images
-
- # # Hack to use the model in full language modeling mode
- # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
- # Make image_attention_mask compatible with hidden states
- text_seq_len = image_attention_mask.size(1)
- image_attention_mask = image_attention_mask.unsqueeze(-1)
- image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
- image_attention_mask = image_attention_mask.view(
- batch_size, text_seq_len, num_images * image_seq_len
- )
- image_batch_size, image_sequence_length, _ = image_hidden_states.size()
- image_hidden_shape = (image_batch_size, image_sequence_length)
- if image_attention_mask is None:
- image_attention_mask = torch.ones(image_hidden_shape, device=device)
- image_attention_mask = self.invert_attention_mask(image_attention_mask)
-
- # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
- # raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
-
- # if image_hidden_states is not None:
- # else:
- # image_attention_mask = None
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- # embed positions
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past),
- dtype=torch.bool,
- device=inputs_embeds.device,
- )
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
-
- hidden_states = inputs_embeds
-
- # if self.gradient_checkpointing and self.training:
- # if use_cache:
- # logger.warning_once(
- # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- # )
- # use_cache = False
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
-
- for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- past_key_value = (
- past_key_values[idx] if past_key_values is not None else None
- )
-
- def vblock(
- main_block,
- hidden_states,
- attention_mask,
- position_ids,
- past_key_value,
- image_hidden_states,
- image_attention_mask,
- output_attentions,
- use_cache,
- no_images,
- layer_idx,
- cross_layer_interval,
- gated_cross_attn_layers,
- ):
- # TODO(ls): Add cross attention values to respective lists
- if layer_idx % cross_layer_interval == 0:
- xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
- outputs = xblock(
- hidden_states,
- attention_mask=attention_mask,
- image_hidden_states=image_hidden_states,
- image_attention_mask=image_attention_mask,
- output_attentions=output_attentions,
- use_cache=use_cache,
- past_key_value=None, # not implemented
- no_images=no_images,
- )
- hidden_states = outputs[0]
-
- layer_outputs = main_block(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- return layer_outputs
-
- # if self.gradient_checkpointing and self.training:
- # past_key_value = None
- # if use_cache:
- # logger.warning_once(
- # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- # )
- # use_cache = False
-
- # layer_outputs = torch.utils.checkpoint.checkpoint(
- # vblock,
- # decoder_layer,
- # hidden_states,
- # attention_mask,
- # position_ids,
- # past_key_value,
- # image_hidden_states,
- # image_attention_mask,
- # output_attentions,
- # use_cache,
- # no_images,
- # idx,
- # self.cross_layer_interval,
- # self.gated_cross_attn_layers,
- # )
- # else:
- layer_outputs = vblock(
- decoder_layer,
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- image_hidden_states=image_hidden_states,
- image_attention_mask=image_attention_mask,
- output_attentions=output_attentions,
- use_cache=use_cache,
- no_images=no_images,
- layer_idx=idx,
- cross_layer_interval=self.cross_layer_interval,
- gated_cross_attn_layers=self.gated_cross_attn_layers,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
- if v is not None
- )
- return BaseModelOutputWithPastImage(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- image_hidden_states=image_hidden_states,
- )
-
-
-class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
- def __init__(
- self,
- config,
- weights,
- ):
- super().__init__(config)
- self.model = IdeficsModel(
- config=config,
- weights=weights,
- )
-
- self.lm_head = IdeficsDecoupledTensorParallelLinear(
- config=config,
- weights=weights,
- )
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_embeddings: Optional[torch.FloatTensor] = None,
- image_hidden_states: Optional[torch.FloatTensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPastImage]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
- >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
- ```"""
-
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- pixel_values=pixel_values,
- image_embeddings=image_embeddings,
- image_hidden_states=image_hidden_states,
- image_attention_mask=image_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits, speculative_logits = self.lm_head(hidden_states)
-
- loss = None
-
- return (
- CausalLMOutputWithPastImage(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- image_hidden_states=outputs.image_hidden_states,
- ),
- speculative_logits,
- )
-
- def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
- inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
- unwanted_kwargs = ["token_type_ids"]
- for kwarg in unwanted_kwargs:
- inputs.pop(kwarg, None)
- return inputs
-
- @staticmethod
- def _expand_inputs_for_generation(
- *args,
- **model_kwargs,
- ):
- return expand_inputs_for_generation(*args, **model_kwargs)
-
- @staticmethod
- def _update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=False
- ):
- return update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
- )
-
- @staticmethod
- def _reorder_cache(past, beam_idx):
- reordered_past = ()
- for layer_past in past:
- reordered_past += (
- tuple(
- past_state.index_select(0, beam_idx) for past_state in layer_past
- ),
- )
- return reordered_past
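# Illustrative sketch (not part of the deleted idefics_modeling.py): a minimal, runnable model of
# the layer-interleaving rule that `vblock` above implements — every `cross_layer_interval`-th
# decoder layer is preceded by a gated cross-attention block indexed by
# `layer_idx // cross_layer_interval`. `nn.Linear` stands in for the real decoder and gated
# cross-attention modules; all sizes below are made up.
import torch
import torch.nn as nn

class TinyInterleavedDecoder(nn.Module):
    def __init__(self, hidden_size=16, n_layers=8, cross_layer_interval=4):
        super().__init__()
        self.cross_layer_interval = cross_layer_interval
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layers)])
        self.gated_cross_attn_layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(n_layers // cross_layer_interval)]
        )

    def forward(self, hidden_states, image_hidden_states):
        for idx, layer in enumerate(self.layers):
            if idx % self.cross_layer_interval == 0:
                xblock = self.gated_cross_attn_layers[idx // self.cross_layer_interval]
                # the real block cross-attends from text to image tokens; a pooled linear map stands in here
                hidden_states = hidden_states + xblock(image_hidden_states.mean(dim=1, keepdim=True))
            hidden_states = layer(hidden_states)
        return hidden_states

text = torch.randn(2, 5, 16)   # (batch, text_seq, hidden)
image = torch.randn(2, 3, 16)  # (batch, image_tokens, hidden)
print(TinyInterleavedDecoder()(text, image).shape)  # torch.Size([2, 5, 16])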
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
deleted file mode 100644
index 6da8045b..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
-#
-# MIT License
-#
-# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-
-"""
-
-Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
-time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
-that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
-prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
-to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
-
-References:
- - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
- - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
-
-"""
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-
-from text_generation_server.layers import (
- TensorParallelColumnLinear,
- TensorParallelRowLinear,
-)
-
-EPS = 1e-5
-
-
-class IdeficsPerceiverResampler(nn.Module):
- def __init__(
- self,
- prefix,
- config,
- embed_dim: int,
- depth: int,
- n_heads: int,
- head_dim: int,
- n_latents: int,
- weights,
- ) -> None:
- """
- Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
- MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
-        returns a Tensor of shape [bsz, n_latents, embed_dim]. `embed_dim` is the dimensionality of the embeddings being
-        fed to the Perceiver Resampler (and also of the latent embeddings *returned* by it); it could be e.g. the ViT
-        embed_dim, the ResNet pool dim, and so on.
-
- Args:
- config (`IdeficsConfig`): config object
- embed_dim (`int`): The size of each embedding vector
- depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
- n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
- head_dim (`int`): Dimensionality of each head projection in the Transformer block.
- n_latents (`int`):
- Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
-
- """
- super().__init__()
- self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
- embed_dim,
- n_heads,
- head_dim,
- n_latents,
- )
- self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
-
- # Create Latents for Perceiver
- self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents"))
-
- self.intermediate_dim = (
- self.embed_dim * 4
- if not hasattr(config.vision_config, "embed_dim")
- else config.vision_config.embed_dim * 4
- )
- # Create Transformer Blocks
- self.blocks = nn.ModuleList(
- [
- nn.ModuleList(
- [
- IdeficsPerceiverAttention(
- prefix=f"{prefix}.blocks.{layer_id}.0",
- config=config,
- embed_dim=self.embed_dim,
- n_heads=self.n_heads,
- head_dim=self.head_dim,
- qk_layer_norms=self.qk_layer_norms,
- weights=weights,
- ),
- IdeficsMLP(
- prefix=f"{prefix}.blocks.{layer_id}.1",
- intermediate_size=self.intermediate_dim,
- config=config,
- weights=weights,
- ),
- ]
- )
- for layer_id in range(depth)
- ]
- )
- self.layer_norm = nn.LayerNorm.load(
- prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
- )
-
- def forward(self, context: torch.Tensor) -> torch.Tensor:
- """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
- # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
- latents = self.latents.repeat(context.shape[0], 1, 1)
-
- # Feed through Perceiver Attention blocks...
- for attn, ff in self.blocks:
- latents = attn(context, latents) + latents
- latents = ff(latents) + latents
-
- return self.layer_norm(latents)
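# Illustrative sketch (not part of the deleted file): the resampling idea the docstring above
# describes, reduced to a single unnormalized cross-attention step. Queries come from a fixed set
# of latents, keys/values from the context concatenated with the latents (mirroring
# IdeficsPerceiverAttention below), so the output length is always n_latents regardless of the
# context length. All sizes are made up.
import torch
import torch.nn.functional as F

def toy_perceiver_resample(context, latents):
    # context: (bsz, seq, d); latents: (bsz, n_latents, d)
    kv = torch.cat([context, latents], dim=-2)
    scores = latents @ kv.transpose(-1, -2) / latents.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ kv

bsz, seq, n_latents, d = 2, 257, 64, 32
context = torch.randn(bsz, seq, d)
latents = torch.randn(n_latents, d).expand(bsz, -1, -1)
print(toy_perceiver_resample(context, latents).shape)  # torch.Size([2, 64, 32])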
-
-
-class IdeficsPerceiverAttention(nn.Module):
- def __init__(
- self,
- prefix,
- config,
- embed_dim: int,
- n_heads: int,
- head_dim: int,
- qk_layer_norms: bool,
- weights,
- ) -> None:
- """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
- super().__init__()
- self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
- self.qk_layer_norms = qk_layer_norms
- # Normalization & Scaling
- self.context_layer_norm = nn.LayerNorm.load(
- prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
- )
- self.latents_layer_norm = nn.LayerNorm.load(
- prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
- )
- if self.qk_layer_norms:
- self.q_layer_norm = nn.LayerNorm.load(
- prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
- )
- self.k_layer_norm = nn.LayerNorm.load(
- prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
- )
-
- self.qk_scale = self.head_dim**-0.5
-
- if n_heads % weights.process_group.size() != 0:
- raise ValueError(
- f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
- f"and `num_shards`: {weights.process_group.size()}"
- )
- self.n_heads //= weights.process_group.size()
-
- # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
- self.q_proj = TensorParallelColumnLinear.load(
- config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
- )
- self.k_proj = TensorParallelColumnLinear.load(
- config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
- )
- self.v_proj = TensorParallelColumnLinear.load(
- config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
- )
-
- self.output_proj = TensorParallelRowLinear.load(
- config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False
- )
-
- def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
- """
- Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
-
- Args:
- context (`torch.Tensor`):
- Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
- latents (`torch.Tensor`):
- Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
-
- Returns:
- `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
- from context.
- """
- context = self.context_layer_norm(context)
- latents = self.latents_layer_norm(latents)
- batch_size, seq_length, embed_dim = context.shape[:3]
-
- # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
- # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
- q = self.q_proj(latents)
- k = self.k_proj(torch.cat([context, latents], dim=-2))
- v = self.v_proj(torch.cat([context, latents], dim=-2))
-
- # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
- # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
- # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
- q, k, v = [
- x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
- 1, 2
- )
- for x in (q, k, v)
- ]
-
- if self.qk_layer_norms:
- q = self.q_layer_norm(q)
- k = self.k_layer_norm(k)
-
- scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
- stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
- attn = stabilized_scores.softmax(dim=-1)
-
- # Attend & project back to output...
- resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
- # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
- return self.output_proj(resampled.transpose(1, 2).flatten(-2))
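# A quick numerical note (not part of the deleted file) on the "stable softmax" used above:
# subtracting the per-row max before exponentiating leaves the softmax result unchanged but
# avoids the overflow a naive exp-normalize hits (the .detach() above additionally keeps the
# max out of the gradient path).
import torch

scores = torch.tensor([[1000.0, 1001.0]])
naive = scores.exp() / scores.exp().sum(dim=-1, keepdim=True)          # exp(1000) overflows -> nan
stable = (scores - scores.amax(dim=-1, keepdim=True)).softmax(dim=-1)  # finite: [0.2689, 0.7311]
print(naive, stable)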
-
-
-class IdeficsMLP(nn.Module):
- def __init__(
- self,
- prefix,
- intermediate_size,
- config,
- weights,
- ):
- """Simple MLP block with intermediate_size and embedding size"""
- super().__init__()
- self.embed_dim = config.vision_config.embed_dim
- self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
- self.fc = TensorParallelColumnLinear.load(
- config=config,
- prefix=f"{prefix}.fc",
- weights=weights,
- bias=False,
- )
- self.act = nn.ReLU()
- self.c_proj = TensorParallelRowLinear.load(
- config=config,
- prefix=f"{prefix}.c_proj",
- weights=weights,
- bias=False,
- )
-
- def forward(
- self, hidden_states: Optional[Tuple[torch.FloatTensor]]
- ) -> torch.FloatTensor:
- hidden_states = self.ln(hidden_states)
- hidden_states = self.fc(hidden_states)
- hidden_states = self.act(hidden_states)
- hidden_states = self.c_proj(hidden_states)
-
- return hidden_states
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py
deleted file mode 100644
index ca61e27d..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Processor class for IDEFICS.
-"""
-
-from typing import Callable, List, Optional, Union
-from urllib.parse import urlparse
-
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import (
- BatchEncoding,
- PaddingStrategy,
- TextInput,
- TruncationStrategy,
-)
-from transformers.utils import TensorType, is_torch_available
-
-
-if is_torch_available():
- import torch
-
-
-IMAGE_TOKEN = "<image>"
-
-
-# copied from m4.training.packing
-def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
- # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
-
- # If any image index is greater than or equal to num_classes, set it to -1.
- # Tokens appearing after the maximum allowed number of images has been seen don't attend to anything
- if num_classes != -1:
- incremental_mask[incremental_mask >= num_classes] = -1
-
- negatives = incremental_mask == -1
- incremental_mask[negatives] = 0
- attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
- attn_mask[negatives, :] = 0
- return attn_mask
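# A small worked example (not part of the deleted file) of the conversion the comment above
# describes, using the helper just defined: per-token incremental image indices become a binary
# mask over `num_classes` images, and -1 (no image seen yet, or past the limit) maps to an
# all-zero row.
import torch

incremental = torch.tensor([[-1, 0, 0, 1]])
binary = incremental_to_binary_attention_mask(incremental.clone(), num_classes=2)
print(binary[0].tolist())  # [[0, 0], [1, 0], [1, 0], [0, 1]]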
-
-
-# copied from m4.training.packing
-def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
- image_attention_mask = torch.full_like(input_ids, fill_value=-1)
- next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
- image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
- eod_token_id = tokenizer.eos_token_id
- for batch_idx in range(input_ids.size(0)):
- count = -1
- seen_eod = False
- for idx, token_id in enumerate(input_ids[batch_idx]):
- if token_id == image_token_id:
- count += 1
- image_attention_mask[batch_idx][idx] = count
- seen_eod = False
- else:
- image_attention_mask[batch_idx][idx] = count
-
- if seen_eod:
- image_attention_mask[batch_idx][idx] = -1
-
- if token_id == eod_token_id:
- seen_eod = True
-
- for batch_idx in range(input_ids.size(0)):
- count = -1
- seen_eod = False
- for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
- token_id = input_ids[batch_idx][idx]
- if token_id == image_token_id:
- count += 1
- next_image_attention_mask[batch_idx][idx] = count
- seen_eod = False
- else:
- next_image_attention_mask[batch_idx][idx] = count
-
- if token_id == eod_token_id:
- seen_eod = True
-
- if seen_eod:
- next_image_attention_mask[batch_idx][idx] = -1
-
- non_negative_indices = next_image_attention_mask[batch_idx] != -1
- next_image_attention_mask[batch_idx][non_negative_indices] -= count
- next_image_attention_mask[batch_idx][non_negative_indices] *= -1
-
- return image_attention_mask, next_image_attention_mask
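# A small worked example (not part of the deleted file) of the helper just defined, with a
# hypothetical stub tokenizer whose <image> id is 9 and whose eos id is 2: every token is tagged
# with the index of the most recent image seen so far, or -1 before any image appears.
import types
import torch

stub_tokenizer = types.SimpleNamespace(convert_tokens_to_ids=lambda token: 9, eos_token_id=2)
input_ids = torch.tensor([[5, 9, 6, 9, 7]])
mask, _ = image_attention_mask_for_packed_input_ids(input_ids, stub_tokenizer)
print(mask.tolist())  # [[-1, 0, 0, 1, 1]]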
-
-
-def is_url(string):
- """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
- invalidated the url"""
- if " " in string:
- return False
- result = urlparse(string)
- return all([result.scheme, result.netloc])
-
-
-def is_image(string):
- """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
- invalidated the url"""
- return is_url(string) or string.startswith("data:")
-
-
-class IdeficsProcessor(ProcessorMixin):
- r"""
- Constructs a IDEFICS processor which wraps a LLama tokenizer and IDEFICS image processor into a single processor.
-
- [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See
- the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
-
- Args:
- image_processor (`IdeficsImageProcessor`):
- An instance of [`IdeficsImageProcessor`]. The image processor is a required input.
- tokenizer (`LlamaTokenizerFast`):
- An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
- image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
- """
-
- attributes = ["image_processor", "tokenizer"]
- image_processor_class = "IdeficsImageProcessor"
- tokenizer_class = "LlamaTokenizerFast"
-
- def __init__(
- self,
- image_processor,
- tokenizer=None,
- image_size=224,
- add_end_of_utterance_token=None,
- **kwargs,
- ):
- if image_processor is None:
- raise ValueError("You need to specify an `image_processor`.")
- if tokenizer is None:
- raise ValueError("You need to specify a `tokenizer`.")
-
- super().__init__(image_processor, tokenizer)
- self.current_processor = self.image_processor
- self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
-
- self.default_image_dims = (
- self.image_processor.image_num_channels,
- self.image_processor.image_size,
- self.image_processor.image_size,
- )
-
- self.tokenizer_was_trained_with_end_of_utterance_token = (
- True
- if ""
- in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
- else False
- )
-
- def __call__(
- self,
- prompts: Union[List[TextInput], List[List[TextInput]]],
- padding: Union[bool, str, PaddingStrategy] = False,
- truncation: Union[bool, str, TruncationStrategy] = None,
- max_length: Optional[int] = None,
- transform: Callable = None,
- add_eos_token=False,
- add_end_of_utterance_token=None,
- debug=False,
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
- ) -> BatchEncoding:
- """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
- the model was trained on and prepares the image pixel values for the model to process.
-
- Args:
- prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):
- either a single prompt or a batched list of prompts - see the detailed description immediately after
- the end of the arguments doc section.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence is provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`, *optional*):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- transform (`Callable`, *optional*):
- A custom transform function that accepts a single image can be passed for training. For example,
- `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
- set of transforms will be applied to the images
- add_eos_token (`bool`, *optional*, defaults to `False`):
- Adds `eos_token` at the end of the final prompt if `True`.
- add_end_of_utterance_token (`bool`, *optional*):
- Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
- image). If `None` the tokenizer will be checked instead and if this token is found in
- `additional_special_tokens` then the value will be `True`.
- debug (`bool`, *optional*, defaults to `False`):
- `True` value will help debug prompt generation by dumping useful information
- return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
- The type of tensors to return. Can be one of:
- - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
-
- Returns:
- a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
- directly passed to `model.generate`
-
- Detailed explanation:
-
- Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
-
- An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
-
- When the processor encounters an image it'll inject a `<fake_token_around_image><image><fake_token_around_image>`
- entry into the prompt.
-
- Example:
-
- ```python
- checkpoint = "HuggingFaceM4/idefics-9b"
- processor = AutoProcessor.from_pretrained(checkpoint)
- url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
- img = processor.image_processor.fetch_images([url])[0]
-
- prompts = [
- "User:",
- img,
- "Describe this image.\nAssistant: An image of two kittens in grass.\n",
- "User:",
- "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
- "Describe this image.\nAssistant:",
- ]
-
- inputs = processor(prompts, return_tensors="pt")
- generated_ids = model.generate(**inputs, max_length=100)
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
- ```
-
- In this example the `prompts` will be converted into:
-
- ```
- <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
- Assistant: An image of two kittens in grass.<end_of_utterance>
- User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
- Assistant:
- ```
-
- and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the
- `pixel_values` dict entry of the return value.
-
- This example also exemplifies that images can be passed as objects or as text urls: the
- first image is passed as an object and the second one as a url.
-
- For training, do:
-
- ```python
- image_transform = transforms.Compose(
- [
- transforms.RandomResizedCrop(
- (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
- ),
- transforms.ToTensor(),
- transforms.Normalize(mean=self.image_mean, std=self.image_std),
- ]
- )
- inputs = processor(prompts, transform=image_transform, return_tensors="pt")
- ```
-
- In order to help debug prompt generation enable `debug=True` which will show you what's happening.
-
- """
-
- # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
- if add_end_of_utterance_token is None:
- add_end_of_utterance_token = (
- self.tokenizer_was_trained_with_end_of_utterance_token
- )
-
- # turn non-batched prompts into batched
- if not any(isinstance(i, list) for i in prompts):
- prompts = [prompts]
-
- fake_token = ""
- image_token = ""
- end_of_utterance_token = ""
-
- def image_tokens(last_was_image):
- if last_was_image:
- return image_token + fake_token
- else:
- return fake_token + image_token + fake_token
-
- all_texts = []
- all_images = []
- for sample in prompts:
- # the model was trained on samples starting with <s>
- full_text = f"{self.tokenizer.bos_token}"
-
- # an image can either be an image object in the item or the url, everything else is a verbatim prompt text
- image_objects = []
- last_was_image = False
- last_was_text = False
- for i, item in enumerate(sample):
- if i > 0:
- last_was_text = True if not last_was_image else False
-
- if isinstance(item, str):
- item = item.strip(" ")
- if is_image(item):
- image = self.image_processor.fetch_images(item)
- full_text += image_tokens(last_was_image)
- image_objects.append(image)
- last_was_image = True
- else:
- # we add end_of_utterance_token between subsequent text prompts (but not after the last one!)
- if add_end_of_utterance_token and last_was_text:
- full_text += end_of_utterance_token
- full_text += item
- last_was_image = False
- else:
- # must be an image obj
- full_text += image_tokens(last_was_image)
- image_objects.append(item)
- last_was_image = True
-
- if add_eos_token:
- full_text += self.tokenizer.eos_token
-
- if debug is True:
- print(f"{full_text=}")
-
- image_objects = self.image_processor(image_objects, transform=transform)
-
- text_encoding = self.tokenizer(
- text=full_text,
- add_special_tokens=False,
- padding=padding,
- truncation=truncation,
- max_length=max_length,
- )
-
- all_texts.append(text_encoding["input_ids"])
- all_images.append(image_objects)
-
- max_seq_len = max(len(x) for x in all_texts)
-
- # max_num_images has to be at least 1 even when there are no images
- max_num_images = max(len(x) for x in all_images)
- max_num_images = max(1, max_num_images)
-
- at_least_one_image = sum(len(x) for x in all_images) > 0
- output_input_ids = []
- output_images = []
- output_attention_masks = []
- for text, images in zip(all_texts, all_images):
- padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len
- unpadded_seq_len = len(text)
- start = max_seq_len - unpadded_seq_len
- padded_input_ids[start:] = text[:max_seq_len]
-
- attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
- attention_mask[start:] = 1
-
- image_count = padded_input_ids.count(self.image_token_id)
- local_max_num_images = min(image_count, max_num_images)
-
- current_images = images[:local_max_num_images]
-
- if len(current_images) > 0:
- padded_image_tensor = torch.zeros(
- max_num_images, *current_images.size()[1:]
- )
- padded_image_tensor[: current_images.size(0)] = current_images
- else:
- padded_image_tensor = torch.zeros(
- max_num_images, *self.default_image_dims
- )
-
- output_images.append(padded_image_tensor)
- output_input_ids.append(torch.tensor(padded_input_ids))
-
- output_attention_masks.append(attention_mask)
-
- output_input_ids = torch.stack(output_input_ids)
- output_images = torch.stack(output_images)
- output_attention_masks = torch.stack(output_attention_masks)
-
- if at_least_one_image:
- image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
- output_input_ids, self.tokenizer
- )
- image_attention_mask = incremental_to_binary_attention_mask(
- image_attention_mask, num_classes=max_num_images
- )
- else:
- # in full language mode we set the image mask to all-0s
- image_attention_mask = torch.zeros(
- output_input_ids.shape[0],
- output_input_ids.shape[1],
- 1,
- dtype=torch.bool,
- )
-
- return BatchFeature(
- data={
- "input_ids": output_input_ids,
- "attention_mask": output_attention_masks,
- "pixel_values": output_images,
- "image_attention_mask": image_attention_mask,
- }
- )
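# A minimal sketch (not part of the deleted file) of the left-padding convention used in
# __call__ above: shorter prompts are padded on the left so every sequence ends at the same
# position and generation continues directly after the prompt. `pad_id=0` is an assumption.
import torch

def left_pad(ids, max_len, pad_id=0):
    padded = [pad_id] * max_len
    padded[max_len - len(ids):] = ids[:max_len]
    mask = torch.zeros(max_len, dtype=torch.long)
    mask[max_len - len(ids):] = 1
    return torch.tensor(padded), mask

ids, mask = left_pad([5, 6, 7], max_len=5)
print(ids.tolist(), mask.tolist())  # [0, 0, 5, 6, 7] [0, 0, 1, 1, 1]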
-
- def batch_decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
- refer to the docstring of this method for more information.
- """
- return self.tokenizer.batch_decode(*args, **kwargs)
-
- def decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
- the docstring of this method for more information.
- """
- return self.tokenizer.decode(*args, **kwargs)
-
- @property
- def model_input_names(self):
- tokenizer_input_names = self.tokenizer.model_input_names
- image_processor_input_names = self.image_processor.model_input_names
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py
deleted file mode 100644
index 7d2051e0..00000000
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py
+++ /dev/null
@@ -1,529 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
-
-
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from transformers.utils import (
- ModelOutput,
- logging,
-)
-from text_generation_server.layers import (
- TensorParallelColumnLinear,
- TensorParallelRowLinear,
- TensorParallelEmbedding,
-)
-
-logger = logging.get_logger(__name__)
-
-
-@dataclass
-class IdeficsVisionModelOutput(ModelOutput):
- """
- Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
-
- Args:
- image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
- The image embeddings obtained by applying the projection layer to the pooler_output.
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- image_embeds: Optional[torch.FloatTensor] = None
- last_hidden_state: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
-# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
-class IdeficsVisionEmbeddings(nn.Module):
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.image_size = config.image_size
- self.patch_size = config.patch_size
-
- self.class_embedding = nn.Parameter(
- weights.get_tensor(f"{prefix}.class_embedding")
- )
-
- self.patch_embedding = nn.Conv2d.load_no_bias(
- prefix=f"{prefix}.patch_embedding",
- weights=weights,
- in_channels=config.num_channels,
- out_channels=self.embed_dim,
- kernel_size=self.patch_size,
- stride=self.patch_size,
- )
-
- self.num_patches = (self.image_size // self.patch_size) ** 2
- self.num_positions = self.num_patches + 1
- self.position_embedding = TensorParallelEmbedding(
- prefix="model.vision_model.embeddings.position_embedding", weights=weights
- )
- self.position_ids = (
- torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
- )
-
- def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
- batch_size = pixel_values.shape[0]
- target_dtype = self.patch_embedding.weight.dtype
- patch_embeds = self.patch_embedding(
- pixel_values.to(dtype=target_dtype)
- ) # shape = [*, width, grid, grid]
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
- embeddings = embeddings + self.position_embedding(self.position_ids)
- return embeddings
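# A shape-only sketch (not part of the deleted file) of the embedding step above. The sizes are
# illustrative (a CLIP ViT-H/14-style encoder: 224x224 images, 14x14 patches, width 1280); the
# learned class embedding and the position embeddings are replaced by zeros / omitted here.
import torch
import torch.nn as nn

image_size, patch_size, embed_dim = 224, 14, 1280
patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
pixel_values = torch.randn(1, 3, image_size, image_size)
patches = patch_embedding(pixel_values).flatten(2).transpose(1, 2)  # (1, 256, 1280): one token per patch
class_embeds = torch.zeros(1, 1, embed_dim)                         # learned parameter in the real model
embeddings = torch.cat([class_embeds, patches], dim=1)              # (1, 257, 1280): class token + patches
print(embeddings.shape)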
-
-
-# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
-class IdeficsVisionAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.embed_dim // self.num_heads
- if self.head_dim * self.num_heads != self.embed_dim:
- raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
- f" {self.num_heads})."
- )
- self.scale = self.head_dim**-0.5
- self.dropout = config.attention_dropout
-
- if self.num_heads % weights.process_group.size() != 0:
- raise ValueError(
- f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
- f"and `num_shards`: {weights.process_group.size()}"
- )
- self.num_heads = self.num_heads // weights.process_group.size()
- self.embed_dim = self.embed_dim // weights.process_group.size()
-
- self.k_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
- )
- self.v_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
- )
- self.q_proj = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
- )
- self.out_proj = TensorParallelRowLinear.load(
- config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
- )
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return (
- tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- .contiguous()
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- causal_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scale
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.view(*proj_shape)
- value_states = value_states.view(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- # apply the causal_attention_mask first
- if causal_attention_mask is not None:
- if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
- f" {causal_attention_mask.size()}"
- )
- attn_weights = (
- attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- + causal_attention_mask
- )
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = (
- attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- + attention_mask
- )
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights has to be reshaped
- # twice and has to be reused in the following
- attn_weights_reshaped = attn_weights.view(
- bsz, self.num_heads, tgt_len, src_len
- )
- attn_weights = attn_weights_reshaped.view(
- bsz * self.num_heads, tgt_len, src_len
- )
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(
- attn_weights, p=self.dropout, training=self.training
- )
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped
-
-
-# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
-class IdeficsVisionMLP(nn.Module):
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.config = config
- self.activation_fn = ACT2FN[config.hidden_act]
- self.fc1 = TensorParallelColumnLinear.load(
- config, prefix=f"{prefix}.fc1", weights=weights, bias=True
- )
- self.fc2 = TensorParallelRowLinear.load(
- config, prefix=f"{prefix}.fc2", weights=weights, bias=True
- )
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.fc1(hidden_states)
- hidden_states = self.activation_fn(hidden_states)
- hidden_states = self.fc2(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
-class IdeficsVisionEncoderLayer(nn.Module):
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.embed_dim = config.hidden_size
- self.self_attn = IdeficsVisionAttention(
- prefix=f"{prefix}.self_attn", config=config, weights=weights
- )
- self.layer_norm1 = nn.LayerNorm.load(
- prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
- )
- self.mlp = IdeficsVisionMLP(
- prefix=f"{prefix}.mlp", config=config, weights=weights
- )
- self.layer_norm2 = nn.LayerNorm.load(
- prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- causal_attention_mask: torch.Tensor,
- output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.FloatTensor]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- """
- residual = hidden_states
-
- hidden_states = self.layer_norm1(hidden_states)
- hidden_states, attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- causal_attention_mask=causal_attention_mask,
- output_attentions=output_attentions,
- )
- hidden_states = residual + hidden_states
-
- residual = hidden_states
- hidden_states = self.layer_norm2(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
-# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
-class IdeficsVisionEncoder(nn.Module):
- """
- Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
- [`IdeficsVisionEncoderLayer`].
-
- Args:
- config: IdeficsVisionConfig
- """
-
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.config = config
- self.layers = nn.ModuleList(
- [
- IdeficsVisionEncoderLayer(
- prefix=f"{prefix}.encoder.layers.{layer_id}",
- config=config,
- weights=weights,
- )
- for layer_id in range(config.num_hidden_layers)
- ]
- )
- # self.gradient_checkpointing = False
-
- def forward(
- self,
- inputs_embeds,
- attention_mask: Optional[torch.Tensor] = None,
- causal_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutput]:
- r"""
- Args:
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
- than the model's internal embedding lookup matrix.
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Causal mask for the text model. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- encoder_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
-
- hidden_states = inputs_embeds
- for idx, encoder_layer in enumerate(self.layers):
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
- # if self.gradient_checkpointing and self.training:
-
- # def create_custom_forward(module):
- # def custom_forward(*inputs):
- # return module(*inputs, output_attentions)
-
- # return custom_forward
-
- # layer_outputs = torch.utils.checkpoint.checkpoint(
- # create_custom_forward(encoder_layer),
- # hidden_states,
- # attention_mask,
- # causal_attention_mask,
- # )
- # else:
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask,
- causal_attention_mask,
- output_attentions=output_attentions,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, encoder_states, all_attentions]
- if v is not None
- )
- return BaseModelOutput(
- last_hidden_state=hidden_states,
- hidden_states=encoder_states,
- attentions=all_attentions,
- )
-
-
-# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
-class IdeficsVisionTransformer(nn.Module):
- def __init__(self, prefix, config, weights):
- super().__init__()
- self.config = config
-
- self.embeddings = IdeficsVisionEmbeddings(
- prefix=f"{prefix}.embeddings", config=config, weights=weights
- )
- self.pre_layrnorm = nn.LayerNorm.load(
- prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
- )
- self.encoder = IdeficsVisionEncoder(
- prefix=prefix, config=config, weights=weights
- )
- self.post_layernorm = nn.LayerNorm.load(
- prefix=f"{prefix}.post_layernorm",
- weights=weights,
- eps=config.layer_norm_eps,
- )
-
- # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
- def forward(
- self,
- pixel_values: Optional[torch.FloatTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
- r"""
- Returns:
-
- """
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- if pixel_values is None:
- raise ValueError("You have to specify pixel_values")
-
- hidden_states = self.embeddings(pixel_values)
- hidden_states = self.pre_layrnorm(hidden_states)
-
- encoder_outputs = self.encoder(
- inputs_embeds=hidden_states,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- last_hidden_state = encoder_outputs[0]
- pooled_output = last_hidden_state[:, 0, :]
- pooled_output = self.post_layernorm(pooled_output)
-
- if not return_dict:
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
-
- return BaseModelOutputWithPooling(
- last_hidden_state=last_hidden_state,
- pooler_output=pooled_output,
- hidden_states=encoder_outputs.hidden_states,
- attentions=encoder_outputs.attentions,
- )
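# A shape-only sketch (not part of the deleted file) of the pooling step above: the vision
# transformer keeps the full patch sequence as `last_hidden_state`, and `pooler_output` is the
# layer-normed class token at position 0. Sizes are illustrative (ViT-H/14-style width 1280).
import torch
import torch.nn as nn

last_hidden_state = torch.randn(1, 257, 1280)   # (batch, class token + 256 patches, hidden)
post_layernorm = nn.LayerNorm(1280)
pooler_output = post_layernorm(last_hidden_state[:, 0, :])
print(last_hidden_state.shape, pooler_output.shape)  # (1, 257, 1280) and (1, 1280)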