Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-24 01:50:17 +00:00)

[gaudi] Refine rope memory, do not need to keep sin/cos cache per layer (#3274)

parent 238fbd4d50
commit 719907410b
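The commit applies one pattern across all of the modeling files below: instead of every attention layer constructing its own rotary embedding (and therefore its own sin/cos cache), the model builds a single shared instance once and passes it down to each layer. A minimal sketch of that pattern, assuming the `PositionRotaryEmbedding.static` helper shown in the diff is importable from the TGI package; the `Attention`/`Model` classes here are illustrative stand-ins, not the actual TGI classes:

import torch.nn as nn
from text_generation_server.layers.rotary import PositionRotaryEmbedding

class Attention(nn.Module):
    def __init__(self, config, weights, rotary_emb):
        super().__init__()
        # The layer no longer builds its own PositionRotaryEmbedding; it reuses
        # the instance owned by the model, so only one sin/cos cache exists per
        # model instead of one per layer.
        self.rotary_emb = rotary_emb

class Model(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        # Built once, shared by every layer.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [Attention(config, weights, rotary_emb) for _ in range(config.num_hidden_layers)]
        )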
@@ -36,7 +36,9 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self.max_position_embeddings = max_position_embeddings
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)

def forward(
self,

@@ -268,9 +270,7 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_cached = torch.sin(freqs).to(dtype)

def get_cos_sin(self, position_ids: torch.Tensor):
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)

cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)

@@ -298,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)

def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,

@@ -351,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)

def _update_cos_sin_cache(self, dtype, device, seqlen):
if (

@@ -592,9 +598,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)

cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
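The rotary.py hunks above add a `get_cos_sin` helper that precomputes the whole table once, up to `max_position_embeddings` in float32 at construction time, and afterwards only indexes into it. A self-contained sketch of that idea; tensor names follow the diff, but the class itself is illustrative and not the TGI implementation:

import torch

class RotaryCacheSketch:
    """Illustrative only: build the cos/sin table once, then index per position."""

    def __init__(self, dim: int, base: float, max_position_embeddings: int, device):
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = torch.cos(freqs)
        self._sin_cached = torch.sin(freqs)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # No recomputation per layer or per step: just gather the rows
        # for the requested positions from the precomputed table.
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin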
@@ -160,18 +160,14 @@ class FlashCohereAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = CohereRotary.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -325,11 +321,14 @@ class CohereMLP(nn.Module):

class FlashCohereLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -385,6 +384,12 @@ class FlashCohereModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = CohereRotary.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashCohereLayer(

@@ -392,6 +397,7 @@ class FlashCohereModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -263,6 +263,7 @@ class DbrxAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.clip_qkv = config.attn_config.clip_qkv

@@ -270,12 +271,7 @@ class DbrxAttention(torch.nn.Module):
self.hidden_size = config.d_model
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.attn_config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -370,13 +366,17 @@ class DbrxNormAttentionNorm(nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.norm_1 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
)
self.self_attn = DbrxAttention(
prefix=f"{prefix}.attn", config=config, weights=weights
prefix=f"{prefix}.attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
self.norm_2 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2",

@@ -601,12 +601,15 @@ class DenseMoE(nn.Module):

class DbrxLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.blocks.{layer_id}"

self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
prefix=f"{prefix}.norm_attn_norm",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)

moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE

@@ -649,6 +652,12 @@ class DbrxModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.wte", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.d_model // config.n_heads,
base=config.attn_config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[

@@ -657,6 +666,7 @@ class DbrxModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.n_layers)
]
@@ -156,6 +156,7 @@ class DeepseekV2Attention(torch.nn.Module):
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -167,13 +168,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim

@@ -459,7 +454,7 @@ class DeepseekV2MoE(nn.Module):

class DeepseekV2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"

@@ -467,6 +462,7 @@ class DeepseekV2Layer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)

if (

@@ -541,6 +537,12 @@ class DeepseekV2Model(torch.nn.Module):
prefix=f"{prefix}.embed_tokens", weights=weights
)

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
DeepseekV2Layer(

@@ -548,6 +550,7 @@ class DeepseekV2Model(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -169,6 +169,7 @@ class DeepseekV3Attention(torch.nn.Module):
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -180,13 +181,7 @@ class DeepseekV3Attention(torch.nn.Module):
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim

@@ -535,7 +530,7 @@ class DeepseekV3MoE(nn.Module):

class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"

@@ -543,6 +538,7 @@ class DeepseekV3Layer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)

if (

@@ -616,6 +612,12 @@ class DeepseekV3Model(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[

@@ -624,6 +626,7 @@ class DeepseekV3Model(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -166,7 +166,14 @@ def _load_gqa(config, prefix: str, weights):

class FlashGemma2Attention(torch.nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -176,13 +183,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.window_size = config.sliding_window
else:
self.window_size = -1

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

# self.softmax_scale = self.head_size**-0.5
self.softmax_scale = config.query_pre_attn_scalar**-0.5

@@ -354,7 +355,14 @@ class Gemma2MLP(nn.Module):

class FlashGemma2Layer(nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
rotary_emb,
):
super().__init__()
self.self_attn = FlashGemma2Attention(

@@ -364,6 +372,7 @@ class FlashGemma2Layer(nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
rotary_emb=rotary_emb,
)
self.mlp = Gemma2MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id

@@ -435,6 +444,13 @@ class FlashGemma2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[
FlashGemma2Layer(

@@ -444,6 +460,7 @@ class FlashGemma2Model(torch.nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=layer_id % 2 == 0,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -119,7 +119,15 @@ def _load_gqa(config, prefix: str, weights):

class FlashGemma3Attention(torch.nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
local_rotary_emb,
global_rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -130,20 +138,10 @@ class FlashGemma3Attention(torch.nn.Module):
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
)
self.rotary_emb = local_rotary_emb
else:
self.window_size = -1
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = global_rotary_emb

self.softmax_scale = (
config.query_pre_attn_scalar**-0.5

@@ -336,7 +334,15 @@ class Gemma3MLP(nn.Module):

class FlashGemma3Layer(nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
local_rotary_emb,
global_rotary_emb,
):
super().__init__()
self.self_attn = FlashGemma3Attention(

@@ -346,6 +352,8 @@ class FlashGemma3Layer(nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
local_rotary_emb=local_rotary_emb,
global_rotary_emb=global_rotary_emb,
)
self.mlp = Gemma3MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id

@@ -417,6 +425,18 @@ class FlashGemma3Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
local_rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
)
global_rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[

@@ -427,6 +447,8 @@ class FlashGemma3Model(torch.nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
local_rotary_emb=local_rotary_emb,
global_rotary_emb=global_rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
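Gemma3 is the one model in this commit that needs two shared rotary embeddings, because sliding-window ("local") layers use rope_local_base_freq while full-attention layers use rope_theta; the model builds both once and each attention layer just picks one. A hedged sketch of that selection, with names mirroring the diff but the class itself simplified for illustration:

import torch.nn as nn

class Gemma3AttentionSketch(nn.Module):
    """Illustrative only: choose the shared local or global rotary embedding per layer."""

    def __init__(self, config, is_sliding: bool, local_rotary_emb, global_rotary_emb):
        super().__init__()
        if is_sliding:
            self.window_size = config.sliding_window
            # Shared instance built with base=config.rope_local_base_freq.
            self.rotary_emb = local_rotary_emb
        else:
            self.window_size = -1
            # Shared instance built with base=config.rope_theta.
            self.rotary_emb = global_rotary_emb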
@@ -163,19 +163,12 @@ def _load_gqa(config, prefix: str, weights):

class FlashGemmaAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)

self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size**-0.5

if self.num_heads % weights.process_group.size() != 0:

@@ -300,10 +293,14 @@ class GemmaMLP(nn.Module):

class FlashGemmaLayer(nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
causal=causal,
rotary_emb=rotary_emb,
)
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -359,6 +356,13 @@ class FlashGemmaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[
FlashGemmaLayer(

@@ -366,6 +370,7 @@ class FlashGemmaModel(torch.nn.Module):
config=config,
weights=weights,
causal=causal,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -110,6 +110,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -143,13 +144,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

self.rotary_emb = GPTJRotary.static(
config=config,
dim=self.rotary_dim,
base=10000,
device=weights.device,
)
self.rotary_emb = rotary_emb

def forward(
self,

@@ -244,10 +239,13 @@ class GPTJMLP(nn.Module):

class FlashGPTJLayer(nn.Module):
def __init__(self, prefix: str, config, weights):
def __init__(self, prefix: str, config, weights, rotary_emb):
super().__init__()
self.self_attn = FlashGPTJAttention(
prefix=f"{prefix}.attn", config=config, weights=weights
prefix=f"{prefix}.attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -291,6 +289,12 @@ class FlashGPTJModel(torch.nn.Module):
self.config = config

self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
rotary_emb = GPTJRotary.static(
config=config,
dim=config.rotary_dim,
base=10000,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashGPTJLayer(

@@ -299,6 +303,7 @@ class FlashGPTJModel(torch.nn.Module):
),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -303,7 +303,7 @@ class Llama4TextAttention(FlashLlamaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, prefix, config, weights, layer_idx):
super().__init__(layer_idx, prefix, config, weights)
super().__init__(layer_idx, prefix, config, weights, None)
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(

@@ -133,6 +133,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -145,13 +146,7 @@ class FlashLlamaAttention(torch.nn.Module):
config, "num_key_value_heads", config.num_attention_heads
)

if config.model_type != "llama4_text":
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(

@@ -376,7 +371,7 @@ class LlamaMLP(nn.Module):

class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights):
def __init__(self, index, prefix, config, weights, rotary_emb):
super().__init__()

with no_fp8(weights):

@@ -385,6 +380,7 @@ class FlashLlamaLayer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)

if config.model_type == "phimoe":

@@ -480,6 +476,13 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(

@@ -487,6 +490,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=f"{prefix}.layers.0",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)

@@ -512,6 +516,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=(f"{prefix}.layers.{layer_id}"),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)

@@ -523,6 +528,7 @@ class FlashLlamaModel(torch.nn.Module):
prefix=(f"{prefix}.layers.{last_layer_id}"),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)
@@ -104,7 +104,7 @@ class MistralConfig(PretrainedConfig):

class MistralAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1

@@ -117,12 +117,7 @@ class MistralAttention(torch.nn.Module):
else:
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -300,13 +295,14 @@ class MistralMLP(nn.Module):

class MistralLayer(nn.Module):
def __init__(self, prefix: str, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
rotary_emb=rotary_emb,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id

@@ -366,6 +362,19 @@ class MistralModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()

if getattr(config, "head_dim", None) is not None:
head_dim = config.head_dim
else:
head_dim = config.hidden_size // config.num_attention_heads

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[
MistralLayer(

@@ -373,6 +382,7 @@ class MistralModel(torch.nn.Module):
config=config,
weights=weights,
layer_id=layer_id,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
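Several of the model-level constructors (Mistral here, and the Qwen3 variants further down) now have to work out the rotary dimension themselves, since the attention layer no longer does it. The convention in the diff is to prefer an explicit config.head_dim and fall back to hidden_size // num_attention_heads. A small illustration of that fallback; `config` is any object with the attributes shown, not a specific TGI class:

def resolve_head_dim(config) -> int:
    # Prefer the explicit head_dim if the config defines one,
    # otherwise derive it from the hidden size and head count.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    return head_dim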
@@ -188,6 +188,7 @@ class MixtralAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.max_past = (

@@ -196,13 +197,7 @@ class MixtralAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -345,12 +340,15 @@ class MixtralMoE(nn.Module):

class MixtralLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"

self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)

moe_layer_cls = (

@@ -416,6 +414,12 @@ class MixtralModel(torch.nn.Module):
weights=weights,
)

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
MixtralLayer(

@@ -423,6 +427,7 @@ class MixtralModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -99,7 +99,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):

class FlashNeoxAttention(torch.nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix, weights, rotary_emb):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size

@@ -116,14 +116,7 @@ class FlashNeoxAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rotary_emb_base,
device=weights.device,
)

self.rotary_emb = rotary_emb
self.softmax_scale = self.head_size ** (-0.5)

self.query_key_value = load_qkv(

@@ -231,7 +224,7 @@ class FlashMLP(nn.Module):

class FlashNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, config, weights, rotary_emb):
super().__init__()

layer_norm_eps = config.layer_norm_eps

@@ -248,7 +241,10 @@ class FlashNeoXLayer(nn.Module):
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention(
config, prefix=f"{prefix}.attention", weights=weights
config,
prefix=f"{prefix}.attention",
weights=weights,
rotary_emb=rotary_emb,
)

self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)

@@ -328,9 +324,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
prefix=f"{prefix}.embed_in", weights=weights
)

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=int(
config.rotary_pct * (config.hidden_size // config.num_attention_heads)
),
base=config.rotary_emb_base,
device=weights.device,
)

self.layers = nn.ModuleList(
[
FlashNeoXLayer(layer_id, config, weights)
FlashNeoXLayer(layer_id, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -113,6 +113,7 @@ class FlashPhiAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads

@@ -121,13 +122,7 @@ class FlashPhiAttention(torch.nn.Module):

self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(

@@ -259,11 +254,14 @@ class PhiMLP(nn.Module):

class FlashPhiLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(

@@ -315,6 +313,16 @@ class FlashPhiModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=int(
config.partial_rotary_factor
* (config.hidden_size // config.num_attention_heads)
),
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[
FlashPhiLayer(

@@ -322,6 +330,7 @@ class FlashPhiModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
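GPT-NeoX and Phi rotate only part of each head, so the model-level constructors above recover the rotary dimension from the config before building the shared embedding (rotary_pct for NeoX, partial_rotary_factor for Phi). A hedged illustration of that computation; `config` is any object exposing the attributes shown:

def neox_rotary_dim(config) -> int:
    # NeoX expresses the rotary fraction as rotary_pct of the head size.
    head_size = config.hidden_size // config.num_attention_heads
    return int(config.rotary_pct * head_size)

def phi_rotary_dim(config) -> int:
    # Phi expresses the same idea as partial_rotary_factor.
    head_size = config.hidden_size // config.num_attention_heads
    return int(config.partial_rotary_factor * head_size)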
@@ -58,6 +58,7 @@ class Qwen2Attention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.max_past = (

@@ -66,13 +67,7 @@ class Qwen2Attention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -199,11 +194,14 @@ class Qwen2MLP(nn.Module):

class Qwen2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(

@@ -258,6 +256,14 @@ class Qwen2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()

rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[
Qwen2Layer(

@@ -265,6 +271,7 @@ class Qwen2Model(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -41,7 +41,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config, prefix, weights, layer_idx):
def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
super().__init__()
self.config = config
self.layer_idx = layer_idx

@@ -54,12 +54,7 @@ class Qwen3Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.attention_dropout = config.attention_dropout
self.softmax_scale = self.head_dim**-0.5
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(

@@ -179,7 +174,7 @@ class Qwen3Attention(nn.Module):

class Qwen3DecoderLayer(nn.Module):
def __init__(self, config, prefix, weights, layer_idx: int):
def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(

@@ -187,6 +182,7 @@ class Qwen3DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
rotary_emb=rotary_emb,
)
self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
self.input_layernorm = FastRMSNorm.load(

@@ -241,6 +237,15 @@ class Qwen3Model(nn.Module):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[

@@ -249,6 +254,7 @@ class Qwen3Model(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
rotary_emb=rotary_emb,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -80,7 +80,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class Qwen3MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config, prefix, weights, layer_idx):
def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
super().__init__()
self.config = config
self.layer_idx = layer_idx

@@ -108,13 +108,7 @@ class Qwen3MoeAttention(nn.Module):
self.o_proj = FastLinear.load(
config, f"{prefix}.o_proj", weights, bias=config.attention_bias
)

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",

@@ -345,7 +339,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

class Qwen3MoeDecoderLayer(nn.Module):
def __init__(self, config, prefix, weights, layer_idx: int):
def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
super().__init__()
self.hidden_size = config.hidden_size

@@ -355,6 +349,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
rotary_emb=rotary_emb,
)
else:
self.self_attn = Qwen3MoeAttention(

@@ -362,6 +357,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
rotary_emb=rotary_emb,
)

moe_layer_cls = (

@@ -433,6 +429,15 @@ class Qwen3MoeModel(nn.Module):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=head_dim,
base=config.rope_theta,
device=weights.device,
)

self.layers = nn.ModuleList(
[

@@ -441,6 +446,7 @@ class Qwen3MoeModel(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
rotary_emb=rotary_emb,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -134,6 +134,7 @@ class FlashRWAttention(torch.nn.Module):
config,
prefix: str,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.n_head

@@ -141,13 +142,8 @@ class FlashRWAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rope_theta = config.rope_theta
self.rotary_emb = rotary_emb

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=self.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size ** (-0.5)

if self.num_heads % weights.process_group.size() != 0:

@@ -243,6 +239,7 @@ class FlashRWLargeAttention(torch.nn.Module):
config,
prefix: str,
weights,
rotary_emb,
):
super().__init__()

@@ -255,13 +252,8 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.num_groups = num_groups
self.rope_theta = config.rope_theta
self.rotary_emb = rotary_emb

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=self.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size ** (-0.5)

# self.num_groups = num_heads // (num_heads_kv * 2)

@@ -382,6 +374,7 @@ class FlashRWLayer(nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()

@@ -404,6 +397,7 @@ class FlashRWLayer(nn.Module):
config,
prefix=f"{prefix}.self_attention",
weights=weights,
rotary_emb=rotary_emb,
)
self.post_attention_layernorm = (
FastLayerNorm.load(

@@ -526,7 +520,7 @@ class FlashRWLayerNorm(nn.Module):

class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, prefix: str, config, weights):
def __init__(self, layer_id, prefix: str, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.h.{layer_id}"

@@ -536,6 +530,7 @@ class FlashRWLargeLayer(nn.Module):
config,
prefix=f"{prefix}.self_attention",
weights=weights,
rotary_emb=rotary_emb,
)
assert config.parallel_attn, "This version doesn't support non parallel_attn"

@@ -593,11 +588,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
self.word_embeddings = TensorParallelEmbedding(
prefix=f"{prefix}.word_embeddings", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.n_head,
base=config.rope_theta,
device=weights.device,
)

if config.new_decoder_architecture:
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, prefix, config, weights)
FlashRWLargeLayer(layer_id, prefix, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)

@@ -605,7 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.h = nn.ModuleList(
[
FlashRWLayer(layer_id, prefix, config, weights)
FlashRWLayer(layer_id, prefix, config, weights, rotary_emb)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -180,6 +180,7 @@ class Starcoder2Attention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.max_past = (

@@ -188,13 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.rotary_emb = rotary_emb

self.softmax_scale = self.head_size**-0.5

@@ -411,11 +406,15 @@ STARCODER2_MLP_CLASSES = {

class Starcoder2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
index=layer_id,
rotary_emb=rotary_emb,
)

self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](

@@ -481,12 +480,19 @@ class Starcoder2Model(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
Starcoder2Layer(
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -1,326 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Idefics model configuration"""
|
||||
import copy
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json",
|
||||
"HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class IdeficsVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
|
||||
Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Idefics-9B.
|
||||
e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`)
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
intermediate_size (`int`, *optional*, defaults to 5120):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
image_num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of image channels.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
initializer_factor (`float`, *optional*, defaults to 1.0):
|
||||
A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
|
||||
testing).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
"""
|
||||
|
||||
model_type = "idefics"
|
||||
attribute_map = {
|
||||
"hidden_size": "embed_dim",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
image_size=224,
|
||||
intermediate_size=5120,
|
||||
patch_size=14,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=16,
|
||||
num_channels=3,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-5,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.embed_dim = embed_dim
|
||||
self.image_size = image_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.patch_size = patch_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.initializer_factor = initializer_factor
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class IdeficsPerceiverConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
|
||||
Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Idefics-9B.
|
||||
e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
use_resampler (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use the resampler
|
||||
resampler_n_latents (`int`, *optional*, defaults to ):
|
||||
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
|
||||
resampler_depth (`int`, *optional*, defaults to 6):
|
||||
Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
|
||||
resampler_n_heads (`int`, *optional*, defaults to 16):
|
||||
Number of heads in each Transformer block (for multi-headed self-attention).
|
||||
resampler_head_dim (`int`, *optional*, defaults to 96):
|
||||
Dimensionality of each head projection in the Transformer block.
|
||||
qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use qk layer norms in perceiver
|
||||
"""
|
||||
|
||||
model_type = "idefics"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_resampler=False,
|
||||
resampler_n_latents=64,
|
||||
resampler_depth=6,
|
||||
resampler_n_heads=16,
|
||||
resampler_head_dim=96,
|
||||
qk_layer_norms_perceiver=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.use_resampler = use_resampler
|
||||
self.resampler_n_latents = resampler_n_latents
|
||||
self.resampler_depth = resampler_depth
|
||||
self.resampler_n_heads = resampler_n_heads
|
||||
self.resampler_head_dim = resampler_head_dim
|
||||
self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class IdeficsConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
|
||||
Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Idefics-9B.
|
||||
e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
additional_vocab_size (`int`, *optional`, defaults to 0):
|
||||
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
||||
are always trainable whereas regular vocab tokens can be frozen or not.
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~IdeficsModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
|
||||
Initialization type for the alphas.
|
||||
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
||||
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
|
||||
Attention.
|
||||
alpha_type (`str`, *optional*, defaults to `"float"`):
|
||||
Whether the gating alphas should be vectors or single floats.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0)
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1)
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2)
|
||||
End of stream token id.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
cross_layer_interval (`int`, *optional*, default to 1)
|
||||
Interval for cross attention (from text to image) layers.
|
||||
qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
|
||||
freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
|
||||
freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):
|
||||
Exceptions to freezing text layers when `freeze_text_layers` is `True`
|
||||
freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
|
||||
freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
|
||||
freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):
|
||||
Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
|
||||
use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
|
||||
vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
|
||||
perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import IdeficsModel, IdeficsConfig
|
||||
>>> # Initializing a Idefics idefics-9b style configuration
|
||||
>>> configuration = IdeficsConfig()
|
||||
>>> # Initializing a model from the idefics-9b style configuration
|
||||
>>> model = IdeficsModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "idefics"
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
additional_vocab_size=0,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
dropout=0.0,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
alpha_initializer="zeros",
|
||||
alphas_initializer_range=0.0,
|
||||
alpha_type="float",
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
cross_layer_interval=1,
|
||||
qk_layer_norms=False,
|
||||
freeze_text_layers=True,
|
||||
freeze_text_module_exceptions=[],
|
||||
freeze_lm_head=False,
|
||||
freeze_vision_layers=True,
|
||||
freeze_vision_module_exceptions=[],
|
||||
use_resampler=False,
|
||||
vision_config=None,
|
||||
perceiver_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.additional_vocab_size = additional_vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.dropout = dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.alpha_initializer = alpha_initializer
|
||||
self.alphas_initializer_range = alphas_initializer_range
|
||||
self.alpha_type = alpha_type
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
|
||||
self.cross_layer_interval = cross_layer_interval
|
||||
self.qk_layer_norms = qk_layer_norms
|
||||
self.freeze_vision_layers = freeze_vision_layers
|
||||
|
||||
self.freeze_text_layers = freeze_text_layers
|
||||
self.freeze_text_module_exceptions = freeze_text_module_exceptions
|
||||
self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
|
||||
self.freeze_lm_head = freeze_lm_head
|
||||
|
||||
self.use_resampler = use_resampler
|
||||
|
||||
if perceiver_config is None:
|
||||
self.perceiver_config = IdeficsPerceiverConfig()
|
||||
elif isinstance(perceiver_config, dict):
|
||||
self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config)
|
||||
elif isinstance(perceiver_config, IdeficsPerceiverConfig):
|
||||
self.perceiver_config = perceiver_config
|
||||
|
||||
if vision_config is None:
|
||||
self.vision_config = IdeficsVisionConfig()
|
||||
elif isinstance(vision_config, dict):
|
||||
self.vision_config = IdeficsVisionConfig(**vision_config)
|
||||
elif isinstance(vision_config, IdeficsVisionConfig):
|
||||
self.vision_config = vision_config
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
||||
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
||||
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
||||
# of this object many attributes have default values and haven't yet been overridden.
|
||||
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["perceiver_config"] = self.perceiver_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
|
||||
return output
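# Illustrative usage (a sketch, not part of the original file), assuming the
# default sub-configs are used:
#
#   config = IdeficsConfig()
#   d = config.to_dict()
#   assert d["model_type"] == "idefics"
#   assert isinstance(d["vision_config"], dict) and isinstance(d["perceiver_config"], dict)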
|
@ -1,297 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for Idefics."""
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union, Iterable
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import transformers
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from transformers.image_transforms import (
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
rescale,
|
||||
normalize,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import requests
|
||||
from transformers import TensorType, is_torch_available
|
||||
|
||||
|
||||
IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
|
||||
IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
|
||||
def convert_to_rgb(image):
|
||||
# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
|
||||
# for transparent images. The call to `alpha_composite` handles this case
|
||||
if image.mode == "RGB":
|
||||
return image
|
||||
|
||||
image_rgba = image.convert("RGBA")
|
||||
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
||||
alpha_composite = Image.alpha_composite(background, image_rgba)
|
||||
alpha_composite = alpha_composite.convert("RGB")
|
||||
return alpha_composite
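# Illustrative sketch (not part of the original file): a quick check of
# convert_to_rgb on a hypothetical fully transparent image; the white background
# is composited in, matching the comment above.
def _demo_convert_to_rgb():
    rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 0))  # fully transparent red
    rgb = convert_to_rgb(rgba)
    assert rgb.mode == "RGB"
    assert rgb.getpixel((0, 0)) == (255, 255, 255)  # transparent pixels become white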
|
||||
|
||||
|
||||
class IdeficsImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs an Idefics image processor.
|
||||
Args:
|
||||
image_size (`int`, *optional*, defaults to `224`):
|
||||
Size to resize the image to.
|
||||
image_num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of image channels.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 224,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
image_num_channels: Optional[int] = 3,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.image_num_channels = image_num_channels
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
image_num_channels: Optional[int] = 3,
|
||||
image_size: Optional[Dict[str, int]] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
transform: Callable = None,
|
||||
**kwargs,
|
||||
) -> TensorType.PYTORCH:
|
||||
"""
|
||||
Preprocess a batch of images.
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
A list of images to preprocess.
|
||||
image_size (`int`, *optional*, defaults to `self.image_size`):
|
||||
Size to resize the image to.
|
||||
image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
|
||||
Number of image channels.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
transform (`Callable`, *optional*, defaults to `None`):
|
||||
A custom transform function that accepts a single image can be passed for training. For example,
|
||||
`torchvision.Compose` can be used to compose multiple transforms. If `None`, inference mode is assumed and a preset of inference-specific transforms is applied to the images.
|
||||
Returns:
|
||||
a PyTorch tensor of the processed images
|
||||
"""
|
||||
image_size = image_size if image_size is not None else self.image_size
|
||||
image_num_channels = (
|
||||
image_num_channels
|
||||
if image_num_channels is not None
|
||||
else self.image_num_channels
|
||||
)
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
size = (image_size, image_size)
|
||||
|
||||
if len(images) == 0:
|
||||
return []
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
# For training a user needs to pass their own set of transforms as a Callable.
|
||||
# For reference this is what was used in the original IDEFICS training:
|
||||
# transform = transforms.Compose([
|
||||
# convert_to_rgb,
|
||||
# transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=image_mean, std=image_std),
|
||||
# ])
|
||||
if transform is not None:
|
||||
if not is_torch_available():
|
||||
raise ImportError("To pass in `transform` torch must be installed")
|
||||
import torch
|
||||
|
||||
images = [transform(x) for x in images]
|
||||
return torch.stack(images)
|
||||
|
||||
# for inference we do the exact transforms that were used to train IDEFICS
|
||||
images = [convert_to_rgb(x) for x in images]
|
||||
# further transforms expect numpy arrays
|
||||
images = [to_numpy_array(x) for x in images]
|
||||
images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
|
||||
images = [self.rescale(image=image, scale=1 / 255) for image in images]
|
||||
images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
|
||||
images = [
|
||||
to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
|
||||
]
|
||||
# TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
|
||||
images = BatchFeature(
|
||||
data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
|
||||
)["pixel_values"]
|
||||
|
||||
return images
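# Shape sketch (illustrative, not part of the original file): for a single
# hypothetical PIL image and the default image_size=224, `preprocess` returns a
# tensor of shape [1, 3, 224, 224] (batch, channels-first, height, width),
# rescaled to [0, 1] and normalized with the CLIP-style mean/std above.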
|
||||
|
||||
def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `PIL.Image` objects.
|
||||
If a single url is passed, the return value will be a single object. If a list is passed, a list of objects is
|
||||
returned.
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
|
||||
" Safari/537.36"
|
||||
)
|
||||
}
|
||||
if isinstance(image_url_or_urls, list):
|
||||
return [self.fetch_images(x) for x in image_url_or_urls]
|
||||
elif isinstance(image_url_or_urls, str):
|
||||
image = image_url_or_urls
|
||||
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(
|
||||
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
|
||||
)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
elif image.startswith("data:"):
|
||||
# https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image
|
||||
# data:image/png;base64,xxx
|
||||
image = image.split(",")[-1]
|
||||
content = base64.b64decode(image)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized image {image}")
|
||||
|
||||
try:
|
||||
image = Image.open(BytesIO(content))
|
||||
# image.verify()
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load image from url {image_url_or_urls}")
|
||||
return image
|
||||
else:
|
||||
raise ValueError(
|
||||
f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
|
||||
)
|
||||
|
||||
def rescale(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
scale: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale an image by a scale factor. image = image * scale.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to rescale.
|
||||
scale (`float`):
|
||||
The scaling factor to rescale pixel values by.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The rescaled image.
|
||||
"""
|
||||
# return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
|
||||
# requires 4.32
|
||||
return rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||
|
||||
def normalize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
mean: Union[float, Iterable[float]],
|
||||
std: Union[float, Iterable[float]],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Normalize an image. image = (image - image_mean) / image_std.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to normalize.
|
||||
mean (`float` or `Iterable[float]`):
|
||||
Image mean to use for normalization.
|
||||
std (`float` or `Iterable[float]`):
|
||||
Image standard deviation to use for normalization.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The normalized image.
|
||||
"""
|
||||
# TODO 4.32
|
||||
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
|
||||
|
||||
|
||||
transformers.IdeficsImageProcessor = IdeficsImageProcessor
|
File diff suppressed because it is too large
@ -1,276 +0,0 @@
|
||||
# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
|
||||
time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
|
||||
that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
|
||||
prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
|
||||
to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
|
||||
|
||||
References:
|
||||
- DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
|
||||
- Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
|
||||
|
||||
"""
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
EPS = 1e-5
|
||||
|
||||
|
||||
class IdeficsPerceiverResampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix,
|
||||
config,
|
||||
embed_dim: int,
|
||||
depth: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
n_latents: int,
|
||||
weights,
|
||||
) -> None:
|
||||
"""
|
||||
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
|
||||
MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
|
||||
returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed
|
||||
to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler).
|
||||
Could be e.g., VIT embed_dim, ResNet pool dim, and so on.
|
||||
|
||||
Args:
|
||||
config (`IdeficsConfig`): config object
|
||||
embed_dim (`int`): The size of each embedding vector
|
||||
depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
|
||||
n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
|
||||
head_dim (`int`): Dimensionality of each head projection in the Transformer block.
|
||||
n_latents (`int`):
|
||||
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
|
||||
embed_dim,
|
||||
n_heads,
|
||||
head_dim,
|
||||
n_latents,
|
||||
)
|
||||
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
|
||||
|
||||
# Create Latents for Perceiver
|
||||
self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents"))
|
||||
|
||||
self.intermediate_dim = (
|
||||
self.embed_dim * 4
|
||||
if not hasattr(config.vision_config, "embed_dim")
|
||||
else config.vision_config.embed_dim * 4
|
||||
)
|
||||
# Create Transformer Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
nn.ModuleList(
|
||||
[
|
||||
IdeficsPerceiverAttention(
|
||||
prefix=f"{prefix}.blocks.{layer_id}.0",
|
||||
config=config,
|
||||
embed_dim=self.embed_dim,
|
||||
n_heads=self.n_heads,
|
||||
head_dim=self.head_dim,
|
||||
qk_layer_norms=self.qk_layer_norms,
|
||||
weights=weights,
|
||||
),
|
||||
IdeficsMLP(
|
||||
prefix=f"{prefix}.blocks.{layer_id}.1",
|
||||
intermediate_size=self.intermediate_dim,
|
||||
config=config,
|
||||
weights=weights,
|
||||
),
|
||||
]
|
||||
)
|
||||
for layer_id in range(depth)
|
||||
]
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
|
||||
def forward(self, context: torch.Tensor) -> torch.Tensor:
|
||||
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
|
||||
# einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
|
||||
latents = self.latents.repeat(context.shape[0], 1, 1)
|
||||
|
||||
# Feed through Perceiver Attention blocks...
|
||||
for attn, ff in self.blocks:
|
||||
latents = attn(context, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
return self.layer_norm(latents)
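# Shape sketch (illustrative, not part of the original file), assuming a
# hypothetical `resampler` built with n_latents=64 and embed_dim=1024:
#
#   context = torch.randn(2, 257, 1024)   # [bsz, seq, embed_dim] from the vision tower
#   latents = resampler(context)          # -> torch.Size([2, 64, 1024])
#
# The output length is always n_latents, independent of the input sequence length.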
|
||||
|
||||
|
||||
class IdeficsPerceiverAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix,
|
||||
config,
|
||||
embed_dim: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
qk_layer_norms: bool,
|
||||
weights,
|
||||
) -> None:
|
||||
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
|
||||
super().__init__()
|
||||
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
|
||||
self.qk_layer_norms = qk_layer_norms
|
||||
# Normalization & Scaling
|
||||
self.context_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
self.latents_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
if self.qk_layer_norms:
|
||||
self.q_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
self.k_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
|
||||
self.qk_scale = self.head_dim**-0.5
|
||||
|
||||
if n_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.n_heads //= weights.process_group.size()
|
||||
|
||||
# Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
|
||||
)
|
||||
self.k_proj = TensorParallelColumnLinear.load(
|
||||
config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
|
||||
)
|
||||
self.v_proj = TensorParallelColumnLinear.load(
|
||||
config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
|
||||
)
|
||||
|
||||
self.output_proj = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False
|
||||
)
|
||||
|
||||
def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
|
||||
|
||||
Args:
|
||||
context (`torch.Tensor`):
|
||||
Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
|
||||
latents (`torch.Tensor`):
|
||||
Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
|
||||
from context.
|
||||
"""
|
||||
context = self.context_layer_norm(context)
|
||||
latents = self.latents_layer_norm(latents)
|
||||
batch_size, seq_length, embed_dim = context.shape[:3]
|
||||
|
||||
# Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
|
||||
# Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
|
||||
q = self.q_proj(latents)
|
||||
k = self.k_proj(torch.cat([context, latents], dim=-2))
|
||||
v = self.v_proj(torch.cat([context, latents], dim=-2))
|
||||
|
||||
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
|
||||
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
|
||||
# einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
|
||||
q, k, v = [
|
||||
x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
|
||||
1, 2
|
||||
)
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
if self.qk_layer_norms:
|
||||
q = self.q_layer_norm(q)
|
||||
k = self.k_layer_norm(k)
|
||||
|
||||
scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
|
||||
stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
|
||||
attn = stabilized_scores.softmax(dim=-1)
|
||||
|
||||
# Attend & project back to output...
|
||||
resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
||||
# einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
|
||||
return self.output_proj(resampled.transpose(1, 2).flatten(-2))
|
||||
|
||||
|
||||
class IdeficsMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix,
|
||||
intermediate_size,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
"""Simple MLP block with intermediate_size and embedding size"""
|
||||
super().__init__()
|
||||
self.embed_dim = config.vision_config.embed_dim
|
||||
self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
|
||||
self.fc = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.fc",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.c_proj = TensorParallelRowLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: Optional[Tuple[torch.FloatTensor]]
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = self.ln(hidden_states)
|
||||
hidden_states = self.fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
|
||||
return hidden_states
|
@ -1,443 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for IDEFICS.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from transformers.tokenization_utils_base import (
|
||||
BatchEncoding,
|
||||
PaddingStrategy,
|
||||
TextInput,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from transformers.utils import TensorType, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
IMAGE_TOKEN = "<image>"
|
||||
|
||||
|
||||
# copied from m4.training.packing
|
||||
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
|
||||
# This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
|
||||
|
||||
# If any of images index are more than num_classes, set them to -1.
|
||||
# Words after the max number of allowed images have been seen don't attend to anything
|
||||
if num_classes != -1:
|
||||
incremental_mask[incremental_mask >= num_classes] = -1
|
||||
|
||||
negatives = incremental_mask == -1
|
||||
incremental_mask[negatives] = 0
|
||||
attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
|
||||
attn_mask[negatives, :] = 0
|
||||
return attn_mask
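# Illustrative sketch (not part of the original file): the conversion above on
# hypothetical indices, with two image slots; rows that were -1 become all-zero.
def _demo_incremental_to_binary_attention_mask():
    mask = torch.tensor([[-1, 0, 0, 1]])
    binary = incremental_to_binary_attention_mask(mask, num_classes=2)
    assert binary.tolist() == [[[0, 0], [1, 0], [1, 0], [0, 1]]]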
|
||||
|
||||
|
||||
# copied from m4.training.packing
|
||||
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
|
||||
image_attention_mask = torch.full_like(input_ids, fill_value=-1)
|
||||
next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
eod_token_id = tokenizer.eos_token_id
|
||||
for batch_idx in range(input_ids.size(0)):
|
||||
count = -1
|
||||
seen_eod = False
|
||||
for idx, token_id in enumerate(input_ids[batch_idx]):
|
||||
if token_id == image_token_id:
|
||||
count += 1
|
||||
image_attention_mask[batch_idx][idx] = count
|
||||
seen_eod = False
|
||||
else:
|
||||
image_attention_mask[batch_idx][idx] = count
|
||||
|
||||
if seen_eod:
|
||||
image_attention_mask[batch_idx][idx] = -1
|
||||
|
||||
if token_id == eod_token_id:
|
||||
seen_eod = True
|
||||
|
||||
for batch_idx in range(input_ids.size(0)):
|
||||
count = -1
|
||||
seen_eod = False
|
||||
for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
|
||||
token_id = input_ids[batch_idx][idx]
|
||||
if token_id == image_token_id:
|
||||
count += 1
|
||||
next_image_attention_mask[batch_idx][idx] = count
|
||||
seen_eod = False
|
||||
else:
|
||||
next_image_attention_mask[batch_idx][idx] = count
|
||||
|
||||
if token_id == eod_token_id:
|
||||
seen_eod = True
|
||||
|
||||
if seen_eod:
|
||||
next_image_attention_mask[batch_idx][idx] = -1
|
||||
|
||||
non_negative_indices = next_image_attention_mask[batch_idx] != -1
|
||||
next_image_attention_mask[batch_idx][non_negative_indices] -= count
|
||||
next_image_attention_mask[batch_idx][non_negative_indices] *= -1
|
||||
|
||||
return image_attention_mask, next_image_attention_mask
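# Worked sketch (illustrative, not part of the original file): for hypothetical
# input ids corresponding to "<s> <image> a b <image> c", the forward mask is
# [-1, 0, 0, 0, 1, 1]: the leading token has seen no image yet, tokens from the
# first <image> up to (but excluding) the second attend to image 0, and the rest
# attend to image 1.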
|
||||
|
||||
|
||||
def is_url(string):
|
||||
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
|
||||
invalidated the url"""
|
||||
if " " in string:
|
||||
return False
|
||||
result = urlparse(string)
|
||||
return all([result.scheme, result.netloc])
|
||||
|
||||
|
||||
def is_image(string):
|
||||
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
|
||||
invalidated the url"""
|
||||
return is_url(string) or string.startswith("data:")
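# Illustrative sketch (not part of the original file), exercising the two helpers
# above on hypothetical inputs.
def _demo_is_image():
    assert is_url("https://example.com/cat.png")
    assert not is_url("not a url")  # spaces invalidate the url
    assert is_image("data:image/png;base64,iVBORw0KGgo=")
    assert not is_image("Describe this image.")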
|
||||
|
||||
|
||||
class IdeficsProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an IDEFICS processor which wraps a LLaMA tokenizer and an IDEFICS image processor into a single processor.
|
||||
|
||||
[`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See
|
||||
the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`IdeficsImageProcessor`):
|
||||
An instance of [`IdeficsImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`LlamaTokenizerFast`):
|
||||
An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
|
||||
image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "IdeficsImageProcessor"
|
||||
tokenizer_class = "LlamaTokenizerFast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor,
|
||||
tokenizer=None,
|
||||
image_size=224,
|
||||
add_end_of_utterance_token=None,
|
||||
**kwargs,
|
||||
):
|
||||
if image_processor is None:
|
||||
raise ValueError("You need to specify an `image_processor`.")
|
||||
if tokenizer is None:
|
||||
raise ValueError("You need to specify a `tokenizer`.")
|
||||
|
||||
super().__init__(image_processor, tokenizer)
|
||||
self.current_processor = self.image_processor
|
||||
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
|
||||
self.default_image_dims = (
|
||||
self.image_processor.image_num_channels,
|
||||
self.image_processor.image_size,
|
||||
self.image_processor.image_size,
|
||||
)
|
||||
|
||||
self.tokenizer_was_trained_with_end_of_utterance_token = (
|
||||
True
|
||||
if "<end_of_utterance>"
|
||||
in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
|
||||
else False
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompts: Union[List[TextInput], List[List[TextInput]]],
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||
max_length: Optional[int] = None,
|
||||
transform: Callable = None,
|
||||
add_eos_token=False,
|
||||
add_end_of_utterance_token=None,
|
||||
debug=False,
|
||||
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
||||
) -> BatchEncoding:
|
||||
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that
|
||||
the model was trained on and prepares the image pixel values for the model to process.
|
||||
|
||||
Args:
|
||||
prompts (`Union[List[TextInput], List[List[TextInput]]]`):
|
||||
either a single prompt or a batched list of prompts - see the detailed description immediately after
|
||||
the end of the arguments doc section.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
truncation (`bool`, *optional*):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
transform (`Callable`, *optional*):
|
||||
A custom transform function that accepts a single image can be passed for training. For example,
|
||||
`torchvision.Compose` can be used to compose multiple functions. If `None`, a preset inference-specific
|
||||
set of transforms will be applied to the images
|
||||
add_eos_token (`bool`, *optional*, defaults to `False`):
|
||||
Adds `eos_token` at the end of the final prompt if `True`.
|
||||
add_end_of_utterance_token (`bool`, *optional*):
|
||||
Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
|
||||
image). If `None` the tokenizer will be checked instead and if this token is found in
|
||||
`additional_special_tokens` then the value will be `True`.
|
||||
debug (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, helps debug prompt generation by printing the fully assembled prompt text.
|
||||
return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
|
||||
The type of tensors to return. Can be one of:
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
|
||||
Returns:
|
||||
a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
|
||||
directly passed to `model.generate`
|
||||
|
||||
Detailed explanation:
|
||||
|
||||
Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
|
||||
|
||||
An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
|
||||
|
||||
When the processor encounters an image it'll inject a `<fake_token_around_image><image><fake_token_around_image>`
|
||||
entry into the prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
checkpoint = "HuggingFaceM4/idefics-9b"
|
||||
processor = AutoProcessor.from_pretrained(checkpoint)
|
||||
url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
|
||||
img = processor.image_processor.fetch_images([url])[0]
|
||||
|
||||
prompts = [
|
||||
"User:",
|
||||
img,
|
||||
"Describe this image.\nAssistant: An image of two kittens in grass.\n",
|
||||
"User:",
|
||||
"https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
|
||||
"Describe this image.\nAssistant:",
|
||||
]
|
||||
|
||||
inputs = processor(prompts, return_tensors="pt")
|
||||
generated_ids = model.generate(**inputs, max_length=100)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
```
|
||||
|
||||
In this example the `prompts` will be converted into:
|
||||
|
||||
```
|
||||
<s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
|
||||
Assistant: An image of two kittens in grass.
|
||||
User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
|
||||
Assistant:
|
||||
```
|
||||
|
||||
and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the
|
||||
`pixel_values` dict entry of the return value.
|
||||
|
||||
This example also shows that images can be passed as objects or as text urls: the first image is passed as an object and the second one as a url.
|
||||
|
||||
For training, do:
|
||||
|
||||
```python
|
||||
image_transform = transforms.Compose(
|
||||
[
|
||||
transforms.RandomResizedCrop(
|
||||
(w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=self.image_mean, std=self.image_std),
|
||||
]
|
||||
)
|
||||
inputs = processor(prompts, transform=image_transform, return_tensors="pt")
|
||||
```
|
||||
|
||||
To help debug prompt generation, enable `debug=True`, which will show you what's happening.
|
||||
|
||||
"""
|
||||
|
||||
# if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
|
||||
if add_end_of_utterance_token is None:
|
||||
add_end_of_utterance_token = (
|
||||
self.tokenizer_was_trained_with_end_of_utterance_token
|
||||
)
|
||||
|
||||
# turn non-batched prompts into batched
|
||||
if not any(isinstance(i, list) for i in prompts):
|
||||
prompts = [prompts]
|
||||
|
||||
fake_token = "<fake_token_around_image>"
|
||||
image_token = "<image>"
|
||||
end_of_utterance_token = "<end_of_utterance>"
|
||||
|
||||
def image_tokens(last_was_image):
|
||||
if last_was_image:
|
||||
return image_token + fake_token
|
||||
else:
|
||||
return fake_token + image_token + fake_token
|
||||
|
||||
all_texts = []
|
||||
all_images = []
|
||||
for sample in prompts:
|
||||
# the model was trained on samples starting with <s>
|
||||
full_text = f"{self.tokenizer.bos_token}"
|
||||
|
||||
# an image can either be an image object in the item or the url, everything else is a verbatim prompt text
|
||||
image_objects = []
|
||||
last_was_image = False
|
||||
last_was_text = False
|
||||
for i, item in enumerate(sample):
|
||||
if i > 0:
|
||||
last_was_text = True if not last_was_image else False
|
||||
|
||||
if isinstance(item, str):
|
||||
item = item.strip(" ")
|
||||
if is_image(item):
|
||||
image = self.image_processor.fetch_images(item)
|
||||
full_text += image_tokens(last_was_image)
|
||||
image_objects.append(image)
|
||||
last_was_image = True
|
||||
else:
|
||||
# we add end_of_utterance_token between subsequent text prompts (but not after the last one!)
|
||||
if add_end_of_utterance_token and last_was_text:
|
||||
full_text += end_of_utterance_token
|
||||
full_text += item
|
||||
last_was_image = False
|
||||
else:
|
||||
# must be an image obj
|
||||
full_text += image_tokens(last_was_image)
|
||||
image_objects.append(item)
|
||||
last_was_image = True
|
||||
|
||||
if add_eos_token:
|
||||
full_text += self.tokenizer.eos_token
|
||||
|
||||
if debug is True:
|
||||
print(f"{full_text=}")
|
||||
|
||||
image_objects = self.image_processor(image_objects, transform=transform)
|
||||
|
||||
text_encoding = self.tokenizer(
|
||||
text=full_text,
|
||||
add_special_tokens=False,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
all_texts.append(text_encoding["input_ids"])
|
||||
all_images.append(image_objects)
|
||||
|
||||
max_seq_len = max(len(x) for x in all_texts)
|
||||
|
||||
# max_num_images has to be at least 1 even when there are no images
|
||||
max_num_images = max(len(x) for x in all_images)
|
||||
max_num_images = max(1, max_num_images)
|
||||
|
||||
at_least_one_image = sum(len(x) for x in all_images) > 0
|
||||
output_input_ids = []
|
||||
output_images = []
|
||||
output_attention_masks = []
|
||||
for text, images in zip(all_texts, all_images):
|
||||
padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len
|
||||
unpadded_seq_len = len(text)
|
||||
start = max_seq_len - unpadded_seq_len
|
||||
padded_input_ids[start:] = text[:max_seq_len]
|
||||
|
||||
attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
|
||||
attention_mask[start:] = 1
|
||||
|
||||
image_count = padded_input_ids.count(self.image_token_id)
|
||||
local_max_num_images = min(image_count, max_num_images)
|
||||
|
||||
current_images = images[:local_max_num_images]
|
||||
|
||||
if len(current_images) > 0:
|
||||
padded_image_tensor = torch.zeros(
|
||||
max_num_images, *current_images.size()[1:]
|
||||
)
|
||||
padded_image_tensor[: current_images.size(0)] = current_images
|
||||
else:
|
||||
padded_image_tensor = torch.zeros(
|
||||
max_num_images, *self.default_image_dims
|
||||
)
|
||||
|
||||
output_images.append(padded_image_tensor)
|
||||
output_input_ids.append(torch.tensor(padded_input_ids))
|
||||
|
||||
output_attention_masks.append(attention_mask)
|
||||
|
||||
output_input_ids = torch.stack(output_input_ids)
|
||||
output_images = torch.stack(output_images)
|
||||
output_attention_masks = torch.stack(output_attention_masks)
|
||||
|
||||
if at_least_one_image:
|
||||
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
|
||||
output_input_ids, self.tokenizer
|
||||
)
|
||||
image_attention_mask = incremental_to_binary_attention_mask(
|
||||
image_attention_mask, num_classes=max_num_images
|
||||
)
|
||||
else:
|
||||
# in full language mode we set the image mask to all-0s
|
||||
image_attention_mask = torch.zeros(
|
||||
output_input_ids.shape[0],
|
||||
output_input_ids.shape[1],
|
||||
1,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"input_ids": output_input_ids,
|
||||
"attention_mask": output_attention_masks,
|
||||
"pixel_values": output_images,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
}
|
||||
)
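# Shape sketch (illustrative, not part of the original file): for a hypothetical
# batch of 2 prompts, each with at most 1 image, whose longest tokenized length
# is 20 and with the default 224x224 image size, the returned entries have shapes
# input_ids [2, 20], attention_mask [2, 20], pixel_values [2, 1, 3, 224, 224] and
# image_attention_mask [2, 20, 1].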
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
@ -1,529 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.utils import (
|
||||
ModelOutput,
|
||||
logging,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdeficsVisionModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
||||
The image embeddings obtained by applying the projection layer to the pooler_output.
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
image_embeds: Optional[torch.FloatTensor] = None
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
|
||||
class IdeficsVisionEmbeddings(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.class_embedding")
|
||||
)
|
||||
|
||||
self.patch_embedding = nn.Conv2d.load_no_bias(
|
||||
prefix=f"{prefix}.patch_embedding",
|
||||
weights=weights,
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix="model.vision_model.embeddings.position_embedding", weights=weights
|
||||
)
|
||||
self.position_ids = (
|
||||
torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype)
|
||||
) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
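# Shape sketch (illustrative, not part of the original file): for a hypothetical
# vision config with image_size=224 and patch_size=14, there are (224/14)**2 = 256
# patch embeddings plus one class embedding, so `forward` returns a tensor of
# shape [batch_size, 257, embed_dim].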
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
|
||||
class IdeficsVisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
|
||||
self.k_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
|
||||
)
|
||||
self.v_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
|
||||
)
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ causal_attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_weights = attn_weights_reshaped.view(
|
||||
bsz * self.num_heads, tgt_len, src_len
|
||||
)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
class IdeficsVisionMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.fc1", weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.fc2", weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

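# Note: fc1 is loaded as a TensorParallelColumnLinear and fc2 as a TensorParallelRowLinear,
# the usual column/row pairing for tensor-parallel MLPs, so the intermediate activation is
# sharded across ranks and fc2 recombines the partial results. The exact reduction behaviour
# is an assumption about the linear-layer implementations, not something this class enforces.
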
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
class IdeficsVisionEncoderLayer(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = IdeficsVisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = IdeficsVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where masked positions are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

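# Note: IdeficsVisionEncoderLayer keeps CLIP's pre-norm layout: layer_norm1 -> self-attention
# -> residual add, then layer_norm2 -> MLP -> residual add, with the attention weights
# appended to the output tuple only when output_attentions is set.
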
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
class IdeficsVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`IdeficsVisionEncoderLayer`].

    Args:
        config: IdeficsVisionConfig
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                IdeficsVisionEncoderLayer(
                    prefix=f"{prefix}.encoder.layers.{layer_id}",
                    config=config,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        # self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # if self.gradient_checkpointing and self.training:

            #     def create_custom_forward(module):
            #         def custom_forward(*inputs):
            #             return module(*inputs, output_attentions)

            #         return custom_forward

            #     layer_outputs = torch.utils.checkpoint.checkpoint(
            #         create_custom_forward(encoder_layer),
            #         hidden_states,
            #         attention_mask,
            #         causal_attention_mask,
            #     )
            # else:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )

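# Note on the encoder outputs above: with output_hidden_states set, encoder_states collects
# the input embeddings plus the output of every layer (num_hidden_layers + 1 entries); with
# return_dict=False the non-None values are returned as a plain tuple instead of a
# BaseModelOutput. The gradient-checkpointing branch is kept only as commented-out code.
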
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
class IdeficsVisionTransformer(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config

        self.embeddings = IdeficsVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
        )
        self.encoder = IdeficsVisionEncoder(
            prefix=prefix, config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

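# Illustrative usage sketch (not part of this module's logic). The prefix string and the
# `config`/`weights`/`pixel_values` objects below are hypothetical placeholders supplied by
# the surrounding Idefics loading code, not defined here:
#
#   vision_model = IdeficsVisionTransformer(
#       prefix="model.vision_model", config=config.vision_config, weights=weights
#   )
#   outputs = vision_model(pixel_values=pixel_values)
#   image_hidden_states = outputs.last_hidden_state  # (batch, num_patches + 1, hidden_size)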