From 719907410b0fc71fad4ca7c36f76b2cf4ad7cb64 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 23 Jun 2025 17:15:39 +0800 Subject: [PATCH] [gaudi] Refine rope memory, do not need to keep sin/cos cache per layer (#3274) --- .../text_generation_server/layers/rotary.py | 17 +- .../custom_modeling/flash_cohere_modeling.py | 22 +- .../custom_modeling/flash_dbrx_modeling.py | 28 +- .../flash_deepseek_v2_modeling.py | 19 +- .../flash_deepseek_v3_modeling.py | 19 +- .../custom_modeling/flash_gemma2_modeling.py | 35 +- .../custom_modeling/flash_gemma3_modeling.py | 50 +- .../custom_modeling/flash_gemma_modeling.py | 27 +- .../custom_modeling/flash_gptj_modeling.py | 23 +- .../custom_modeling/flash_llama4_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 22 +- .../custom_modeling/flash_mistral_modeling.py | 26 +- .../custom_modeling/flash_mixtral_modeling.py | 23 +- .../custom_modeling/flash_neox_modeling.py | 29 +- .../custom_modeling/flash_phi_modeling.py | 27 +- .../custom_modeling/flash_qwen2_modeling.py | 25 +- .../custom_modeling/flash_qwen3_modeling.py | 22 +- .../flash_qwen3_moe_modeling.py | 24 +- .../custom_modeling/flash_rw_modeling.py | 31 +- .../flash_starcoder2_modeling.py | 24 +- .../models/custom_modeling/idefics_config.py | 326 ---- .../idefics_image_processing.py | 297 ---- .../custom_modeling/idefics_modeling.py | 1474 ----------------- .../custom_modeling/idefics_perceiver.py | 276 --- .../custom_modeling/idefics_processing.py | 443 ----- .../models/custom_modeling/idefics_vision.py | 529 ------ 26 files changed, 315 insertions(+), 3525 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index d381d4c6..7e740e5f 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -36,7 +36,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None - self.max_position_embeddings = max_position_embeddings + self._update_cos_sin_cache( + torch.float32, inv_freq.device, max_position_embeddings + ) def forward( self, @@ -268,9 +270,7 @@ class PositionRotaryEmbedding(nn.Module): self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin(self, position_ids: torch.Tensor): - self._update_cos_sin_cache( - torch.float32, position_ids.device, seqlen=self.max_position_embeddings - ) + cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -298,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -351,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -592,9 +598,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): position_ids: torch.Tensor, ): slen = position_ids.shape[0] - self._update_cos_sin_cache( - torch.float32, position_ids.device, seqlen=self.max_position_embeddings - ) cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 7a32a85c..367c26c9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -160,18 +160,14 @@ class FlashCohereAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = CohereRotary.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -325,11 +321,14 @@ class CohereMLP(nn.Module): class FlashCohereLayer(nn.Module): - def __init__(self, prefix: str, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashCohereAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -385,6 +384,12 @@ class FlashCohereModel(torch.nn.Module): self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) + rotary_emb = CohereRotary.static( + config=config, + dim=config.hidden_size // config.num_attention_heads, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ FlashCohereLayer( @@ -392,6 +397,7 @@ class FlashCohereModel(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 42af7798..c097f71e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -263,6 +263,7 @@ class DbrxAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.clip_qkv = config.attn_config.clip_qkv @@ -270,12 +271,7 @@ class DbrxAttention(torch.nn.Module): self.hidden_size = config.d_model self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.attn_config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -370,13 +366,17 @@ class DbrxNormAttentionNorm(nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.norm_1 = FastLayerNorm.load_no_bias( prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5 ) self.self_attn = DbrxAttention( - prefix=f"{prefix}.attn", config=config, weights=weights + prefix=f"{prefix}.attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) self.norm_2 = FastLayerNorm.load_no_bias( prefix=f"{prefix}.norm_2", @@ -601,12 +601,15 @@ class DenseMoE(nn.Module): class DbrxLayer(nn.Module): - def __init__(self, prefix: str, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.blocks.{layer_id}" self.attn = DbrxNormAttentionNorm( - prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights + prefix=f"{prefix}.norm_attn_norm", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE @@ -649,6 +652,12 @@ class DbrxModel(torch.nn.Module): self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.wte", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.d_model // config.n_heads, + base=config.attn_config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ @@ -657,6 +666,7 @@ class DbrxModel(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.n_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 8e9002a2..08b7d99d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -156,6 +156,7 @@ class DeepseekV2Attention(torch.nn.Module): prefix: str, config, weights: Weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -167,13 +168,7 @@ class DeepseekV2Attention(torch.nn.Module): self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim self.value_head_size = config.v_head_dim self.head_pad_size = max(self.head_size, self.value_head_size) - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.qk_rope_head_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb mscale = get_mscale( self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim @@ -459,7 +454,7 @@ class DeepseekV2MoE(nn.Module): class DeepseekV2Layer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -467,6 +462,7 @@ class DeepseekV2Layer(nn.Module): prefix=f"{prefix}.self_attn", config=config, weights=weights, + rotary_emb=rotary_emb, ) if ( @@ -541,6 +537,12 @@ class DeepseekV2Model(torch.nn.Module): prefix=f"{prefix}.embed_tokens", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ DeepseekV2Layer( @@ -548,6 +550,7 @@ class DeepseekV2Model(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 8e058093..3a6a974a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -169,6 +169,7 @@ class DeepseekV3Attention(torch.nn.Module): prefix: str, config, weights: Weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -180,13 +181,7 @@ class DeepseekV3Attention(torch.nn.Module): self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim self.value_head_size = config.v_head_dim self.head_pad_size = max(self.head_size, self.value_head_size) - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.qk_rope_head_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb mscale = get_mscale( self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim @@ -535,7 +530,7 @@ class DeepseekV3MoE(nn.Module): class DeepseekV3Layer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -543,6 +538,7 @@ class DeepseekV3Layer(nn.Module): prefix=f"{prefix}.self_attn", config=config, weights=weights, + rotary_emb=rotary_emb, ) if ( @@ -616,6 +612,12 @@ class DeepseekV3Model(torch.nn.Module): self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ @@ -624,6 +626,7 @@ class DeepseekV3Model(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a1a20999..74d9397e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -166,7 +166,14 @@ def _load_gqa(config, prefix: str, weights): class FlashGemma2Attention(torch.nn.Module): def __init__( - self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + self, + prefix: str, + config, + weights, + layer_id, + causal: bool, + is_sliding: bool, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -176,13 +183,7 @@ class FlashGemma2Attention(torch.nn.Module): self.window_size = config.sliding_window else: self.window_size = -1 - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb # self.softmax_scale = self.head_size**-0.5 self.softmax_scale = config.query_pre_attn_scalar**-0.5 @@ -354,7 +355,14 @@ class Gemma2MLP(nn.Module): class FlashGemma2Layer(nn.Module): def __init__( - self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + self, + prefix: str, + config, + weights, + layer_id, + causal: bool, + is_sliding: bool, + rotary_emb, ): super().__init__() self.self_attn = FlashGemma2Attention( @@ -364,6 +372,7 @@ class FlashGemma2Layer(nn.Module): layer_id=layer_id, causal=causal, is_sliding=is_sliding, + rotary_emb=rotary_emb, ) self.mlp = Gemma2MLP( prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id @@ -435,6 +444,13 @@ class FlashGemma2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_theta, + device=weights.device, + ) + self.layers = nn.ModuleList( [ FlashGemma2Layer( @@ -444,6 +460,7 @@ class FlashGemma2Model(torch.nn.Module): layer_id=layer_id, causal=causal, is_sliding=layer_id % 2 == 0, + rotary_emb=rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 92f059bc..7b789d30 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -119,7 +119,15 @@ def _load_gqa(config, prefix: str, weights): class FlashGemma3Attention(torch.nn.Module): def __init__( - self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + self, + prefix: str, + config, + weights, + layer_id, + causal: bool, + is_sliding: bool, + local_rotary_emb, + global_rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -130,20 +138,10 @@ class FlashGemma3Attention(torch.nn.Module): # TODO: remove this hack to support local sliding window config = copy.deepcopy(config) config.rope_scaling = dict(rope_type="default") - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=config.head_dim, - base=config.rope_local_base_freq, - device=weights.device, - ) + self.rotary_emb = local_rotary_emb else: self.window_size = -1 - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=config.head_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = global_rotary_emb self.softmax_scale = ( config.query_pre_attn_scalar**-0.5 @@ -336,7 +334,15 @@ class Gemma3MLP(nn.Module): class FlashGemma3Layer(nn.Module): def __init__( - self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + self, + prefix: str, + config, + weights, + layer_id, + causal: bool, + is_sliding: bool, + local_rotary_emb, + global_rotary_emb, ): super().__init__() self.self_attn = FlashGemma3Attention( @@ -346,6 +352,8 @@ class FlashGemma3Layer(nn.Module): layer_id=layer_id, causal=causal, is_sliding=is_sliding, + local_rotary_emb=local_rotary_emb, + global_rotary_emb=global_rotary_emb, ) self.mlp = Gemma3MLP( prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id @@ -417,6 +425,18 @@ class FlashGemma3Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + local_rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_local_base_freq, + device=weights.device, + ) + global_rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ @@ -427,6 +447,8 @@ class FlashGemma3Model(torch.nn.Module): layer_id=layer_id, causal=causal, is_sliding=bool((layer_id + 1) % config.sliding_window_pattern), + local_rotary_emb=local_rotary_emb, + global_rotary_emb=global_rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 7a2ec22e..5d6dc67c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -163,19 +163,12 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim self.causal = causal - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) - + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: @@ -300,10 +293,14 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb): super().__init__() self.self_attn = FlashGemmaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + causal=causal, + rotary_emb=rotary_emb, ) self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -359,6 +356,13 @@ class FlashGemmaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.head_dim, + base=config.rope_theta, + device=weights.device, + ) + self.layers = nn.ModuleList( [ FlashGemmaLayer( @@ -366,6 +370,7 @@ class FlashGemmaModel(torch.nn.Module): config=config, weights=weights, causal=causal, + rotary_emb=rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 679380a1..1e7a867c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -110,6 +110,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -143,13 +144,7 @@ class FlashGPTJAttention(torch.nn.Module): self.kv_head_mapping = torch.arange( 0, self.num_heads, dtype=torch.int32, device=weights.device ) - - self.rotary_emb = GPTJRotary.static( - config=config, - dim=self.rotary_dim, - base=10000, - device=weights.device, - ) + self.rotary_emb = rotary_emb def forward( self, @@ -244,10 +239,13 @@ class GPTJMLP(nn.Module): class FlashGPTJLayer(nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, rotary_emb): super().__init__() self.self_attn = FlashGPTJAttention( - prefix=f"{prefix}.attn", config=config, weights=weights + prefix=f"{prefix}.attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -291,6 +289,12 @@ class FlashGPTJModel(torch.nn.Module): self.config = config self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights) + rotary_emb = GPTJRotary.static( + config=config, + dim=config.rotary_dim, + base=10000, + device=weights.device, + ) self.layers = nn.ModuleList( [ FlashGPTJLayer( @@ -299,6 +303,7 @@ class FlashGPTJModel(torch.nn.Module): ), config=config, weights=weights, + rotary_emb=rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 3b30f2e0..1db8ad10 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -303,7 +303,7 @@ class Llama4TextAttention(FlashLlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, prefix, config, weights, layer_idx): - super().__init__(layer_idx, prefix, config, weights) + super().__init__(layer_idx, prefix, config, weights, None) self.config = config self.layer_idx = layer_idx self.head_dim = getattr( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 70fcc824..fbfcd39c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -133,6 +133,7 @@ class FlashLlamaAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -145,13 +146,7 @@ class FlashLlamaAttention(torch.nn.Module): config, "num_key_value_heads", config.num_attention_heads ) - if config.model_type != "llama4_text": - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb # `config.attention_multiplier` is used in Granite self.softmax_scale = getattr( @@ -376,7 +371,7 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): - def __init__(self, index, prefix, config, weights): + def __init__(self, index, prefix, config, weights, rotary_emb): super().__init__() with no_fp8(weights): @@ -385,6 +380,7 @@ class FlashLlamaLayer(nn.Module): prefix=f"{prefix}.self_attn", config=config, weights=weights, + rotary_emb=rotary_emb, ) if config.model_type == "phimoe": @@ -480,6 +476,13 @@ class FlashLlamaModel(torch.nn.Module): # Skip fp8 quant for first and last layers self.layers = nn.ModuleList() self.cross_attention_layers = getattr(config, "cross_attention_layers", []) + + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.hidden_size // config.num_attention_heads, + base=config.rope_theta, + device=weights.device, + ) with no_fp8(weights): self.layers.append( FlashLlamaLayer( @@ -487,6 +490,7 @@ class FlashLlamaModel(torch.nn.Module): prefix=f"{prefix}.layers.0", config=config, weights=weights, + rotary_emb=rotary_emb, ) ) @@ -512,6 +516,7 @@ class FlashLlamaModel(torch.nn.Module): prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, + rotary_emb=rotary_emb, ) ) @@ -523,6 +528,7 @@ class FlashLlamaModel(torch.nn.Module): prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, + rotary_emb=rotary_emb, ) ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 008df32d..f7aed118 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -104,7 +104,7 @@ class MistralConfig(PretrainedConfig): class MistralAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id, rotary_emb): super().__init__() self.max_past = ( config.sliding_window if config.sliding_window is not None else -1 @@ -117,12 +117,7 @@ class MistralAttention(torch.nn.Module): else: self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -300,13 +295,14 @@ class MistralMLP(nn.Module): class MistralLayer(nn.Module): - def __init__(self, prefix: str, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id, rotary_emb): super().__init__() self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + rotary_emb=rotary_emb, ) self.mlp = MistralMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id @@ -366,6 +362,19 @@ class MistralModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + + if getattr(config, "head_dim", None) is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=head_dim, + base=config.rope_theta, + device=weights.device, + ) + self.layers = nn.ModuleList( [ MistralLayer( @@ -373,6 +382,7 @@ class MistralModel(torch.nn.Module): config=config, weights=weights, layer_id=layer_id, + rotary_emb=rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 4993b444..8c682c7f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -188,6 +188,7 @@ class MixtralAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.max_past = ( @@ -196,13 +197,7 @@ class MixtralAttention(torch.nn.Module): self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -345,12 +340,15 @@ class MixtralMoE(nn.Module): class MixtralLayer(nn.Module): - def __init__(self, prefix: str, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" self.self_attn = MixtralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) moe_layer_cls = ( @@ -416,6 +414,12 @@ class MixtralModel(torch.nn.Module): weights=weights, ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.hidden_size // config.num_attention_heads, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ MixtralLayer( @@ -423,6 +427,7 @@ class MixtralModel(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 6e1050b6..8ee1dfa2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -99,7 +99,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): class FlashNeoxAttention(torch.nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix, weights, rotary_emb): super().__init__() num_heads = config.num_attention_heads hidden_size = config.hidden_size @@ -116,14 +116,7 @@ class FlashNeoxAttention(torch.nn.Module): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.rotary_dim, - base=config.rotary_emb_base, - device=weights.device, - ) - + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv( @@ -231,7 +224,7 @@ class FlashMLP(nn.Module): class FlashNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, config, weights, rotary_emb): super().__init__() layer_norm_eps = config.layer_norm_eps @@ -248,7 +241,10 @@ class FlashNeoXLayer(nn.Module): eps=layer_norm_eps, ) self.attention = FlashNeoxAttention( - config, prefix=f"{prefix}.attention", weights=weights + config, + prefix=f"{prefix}.attention", + weights=weights, + rotary_emb=rotary_emb, ) self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) @@ -328,9 +324,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): prefix=f"{prefix}.embed_in", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=int( + config.rotary_pct * (config.hidden_size // config.num_attention_heads) + ), + base=config.rotary_emb_base, + device=weights.device, + ) + self.layers = nn.ModuleList( [ - FlashNeoXLayer(layer_id, config, weights) + FlashNeoXLayer(layer_id, config, weights, rotary_emb) for layer_id in range(config.num_hidden_layers) ] ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 78aaf0d5..d7fc844b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -113,6 +113,7 @@ class FlashPhiAttention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.num_heads = config.num_attention_heads @@ -121,13 +122,7 @@ class FlashPhiAttention(torch.nn.Module): self.softmax_scale = self.head_size**-0.5 self.rotary_dim = int(config.partial_rotary_factor * self.head_size) - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.rotary_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -259,11 +254,14 @@ class PhiMLP(nn.Module): class FlashPhiLayer(nn.Module): - def __init__(self, prefix: str, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashPhiAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastLayerNorm.load( @@ -315,6 +313,16 @@ class FlashPhiModel(torch.nn.Module): self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=int( + config.partial_rotary_factor + * (config.hidden_size // config.num_attention_heads) + ), + base=config.rope_theta, + device=weights.device, + ) + self.layers = nn.ModuleList( [ FlashPhiLayer( @@ -322,6 +330,7 @@ class FlashPhiModel(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index ac31e53b..13e1c916 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -58,6 +58,7 @@ class Qwen2Attention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.max_past = ( @@ -66,13 +67,7 @@ class Qwen2Attention(torch.nn.Module): self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -199,11 +194,14 @@ class Qwen2MLP(nn.Module): class Qwen2Layer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.layers.{layer_id}" self.self_attn = Qwen2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + rotary_emb=rotary_emb, ) self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( @@ -258,6 +256,14 @@ class Qwen2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.hidden_size // config.num_attention_heads, + base=config.rope_theta, + device=weights.device, + ) + self.layers = nn.ModuleList( [ Qwen2Layer( @@ -265,6 +271,7 @@ class Qwen2Model(torch.nn.Module): layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 8bd00c13..63ee4c97 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -41,7 +41,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config, prefix, weights, layer_idx): + def __init__(self, config, prefix, weights, layer_idx, rotary_emb): super().__init__() self.config = config self.layer_idx = layer_idx @@ -54,12 +54,7 @@ class Qwen3Attention(nn.Module): self.num_heads = config.num_attention_heads self.attention_dropout = config.attention_dropout self.softmax_scale = self.head_dim**-0.5 - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -179,7 +174,7 @@ class Qwen3Attention(nn.Module): class Qwen3DecoderLayer(nn.Module): - def __init__(self, config, prefix, weights, layer_idx: int): + def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention( @@ -187,6 +182,7 @@ class Qwen3DecoderLayer(nn.Module): prefix=f"{prefix}.self_attn", weights=weights, layer_idx=layer_idx, + rotary_emb=rotary_emb, ) self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) self.input_layernorm = FastRMSNorm.load( @@ -241,6 +237,15 @@ class Qwen3Model(nn.Module): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=head_dim, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ @@ -249,6 +254,7 @@ class Qwen3Model(nn.Module): prefix=f"{prefix}.layers.{layer_idx}", weights=weights, layer_idx=layer_idx, + rotary_emb=rotary_emb, ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py index 5e4bc7fa..da474adc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py @@ -80,7 +80,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config, prefix, weights, layer_idx): + def __init__(self, config, prefix, weights, layer_idx, rotary_emb): super().__init__() self.config = config self.layer_idx = layer_idx @@ -108,13 +108,7 @@ class Qwen3MoeAttention(nn.Module): self.o_proj = FastLinear.load( config, f"{prefix}.o_proj", weights, bias=config.attention_bias ) - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_dim, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.q_norm = FastRMSNorm.load( prefix=f"{prefix}.q_norm", @@ -345,7 +339,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): class Qwen3MoeDecoderLayer(nn.Module): - def __init__(self, config, prefix, weights, layer_idx: int): + def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb): super().__init__() self.hidden_size = config.hidden_size @@ -355,6 +349,7 @@ class Qwen3MoeDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn", weights=weights, layer_idx=layer_idx, + rotary_emb=rotary_emb, ) else: self.self_attn = Qwen3MoeAttention( @@ -362,6 +357,7 @@ class Qwen3MoeDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn", weights=weights, layer_idx=layer_idx, + rotary_emb=rotary_emb, ) moe_layer_cls = ( @@ -433,6 +429,15 @@ class Qwen3MoeModel(nn.Module): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=head_dim, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ @@ -441,6 +446,7 @@ class Qwen3MoeModel(nn.Module): prefix=f"{prefix}.layers.{layer_idx}", weights=weights, layer_idx=layer_idx, + rotary_emb=rotary_emb, ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 06616f85..bd8397f1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -134,6 +134,7 @@ class FlashRWAttention(torch.nn.Module): config, prefix: str, weights, + rotary_emb, ): super().__init__() self.num_heads = config.n_head @@ -141,13 +142,8 @@ class FlashRWAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads self.rope_theta = config.rope_theta + self.rotary_emb = rotary_emb - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=self.rope_theta, - device=weights.device, - ) self.softmax_scale = self.head_size ** (-0.5) if self.num_heads % weights.process_group.size() != 0: @@ -243,6 +239,7 @@ class FlashRWLargeAttention(torch.nn.Module): config, prefix: str, weights, + rotary_emb, ): super().__init__() @@ -255,13 +252,8 @@ class FlashRWLargeAttention(torch.nn.Module): self.head_size = hidden_size // num_heads self.num_groups = num_groups self.rope_theta = config.rope_theta + self.rotary_emb = rotary_emb - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=self.rope_theta, - device=weights.device, - ) self.softmax_scale = self.head_size ** (-0.5) # self.num_groups = num_heads // (num_heads_kv * 2) @@ -382,6 +374,7 @@ class FlashRWLayer(nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() @@ -404,6 +397,7 @@ class FlashRWLayer(nn.Module): config, prefix=f"{prefix}.self_attention", weights=weights, + rotary_emb=rotary_emb, ) self.post_attention_layernorm = ( FastLayerNorm.load( @@ -526,7 +520,7 @@ class FlashRWLayerNorm(nn.Module): class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, prefix: str, config, weights): + def __init__(self, layer_id, prefix: str, config, weights, rotary_emb): super().__init__() prefix = f"{prefix}.h.{layer_id}" @@ -536,6 +530,7 @@ class FlashRWLargeLayer(nn.Module): config, prefix=f"{prefix}.self_attention", weights=weights, + rotary_emb=rotary_emb, ) assert config.parallel_attn, "This version doesn't support non parallel_attn" @@ -593,11 +588,17 @@ class FlashRWModel(FlashRWPreTrainedModel): self.word_embeddings = TensorParallelEmbedding( prefix=f"{prefix}.word_embeddings", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.hidden_size // config.n_head, + base=config.rope_theta, + device=weights.device, + ) if config.new_decoder_architecture: self.h = nn.ModuleList( [ - FlashRWLargeLayer(layer_id, prefix, config, weights) + FlashRWLargeLayer(layer_id, prefix, config, weights, rotary_emb) for layer_id in range(config.num_hidden_layers) ] ) @@ -605,7 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel): else: self.h = nn.ModuleList( [ - FlashRWLayer(layer_id, prefix, config, weights) + FlashRWLayer(layer_id, prefix, config, weights, rotary_emb) for layer_id in range(config.num_hidden_layers) ] ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 1a749595..45baf4db 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -180,6 +180,7 @@ class Starcoder2Attention(torch.nn.Module): prefix: str, config, weights, + rotary_emb, ): super().__init__() self.max_past = ( @@ -188,13 +189,7 @@ class Starcoder2Attention(torch.nn.Module): self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = rotary_emb self.softmax_scale = self.head_size**-0.5 @@ -411,11 +406,15 @@ STARCODER2_MLP_CLASSES = { class Starcoder2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, config, weights, rotary_emb): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + index=layer_id, + rotary_emb=rotary_emb, ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( @@ -481,12 +480,19 @@ class Starcoder2Model(torch.nn.Module): self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) + rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=config.hidden_size // config.num_attention_heads, + base=config.rope_theta, + device=weights.device, + ) self.layers = nn.ModuleList( [ Starcoder2Layer( layer_id, config, weights, + rotary_emb, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py deleted file mode 100644 index 6ce2054e..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py +++ /dev/null @@ -1,326 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Idefics model configuration""" -import copy - -from transformers import PretrainedConfig - -IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json", - "HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json", -} - - -class IdeficsVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an - Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Idefics-9B. - e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - hidden_size (`int`, *optional*, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`) - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - intermediate_size (`int`, *optional*, defaults to 5120): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - patch_size (`int`, *optional*, defaults to 14): - The size (resolution) of each patch. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer encoder. - image_num_channels (`int`, *optional*, defaults to `3`): - Number of image channels. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. - layer_norm_eps (`float`, *optional*, defaults to 1e-5): - The epsilon used by the layer normalization layers. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - initializer_factor (`float`, *optional*, defaults to 1.0): - A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization - testing). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - """ - - model_type = "idefics" - attribute_map = { - "hidden_size": "embed_dim", - } - - def __init__( - self, - embed_dim=768, - image_size=224, - intermediate_size=5120, - patch_size=14, - num_hidden_layers=32, - num_attention_heads=16, - num_channels=3, - hidden_act="gelu", - layer_norm_eps=1e-5, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - **kwargs, - ): - self.embed_dim = embed_dim - self.image_size = image_size - self.intermediate_size = intermediate_size - self.patch_size = patch_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.layer_norm_eps = layer_norm_eps - self.attention_dropout = attention_dropout - self.initializer_range = initializer_range - self.initializer_factor = initializer_factor - self.hidden_act = hidden_act - - super().__init__(**kwargs) - - -class IdeficsPerceiverConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an - Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Idefics-9B. - e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - use_resampler (`bool`, *optional*, defaults to `False`): - Whether or not to use the resampler - resampler_n_latents (`int`, *optional*, defaults to ): - Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). - resampler_depth (`int`, *optional*, defaults to 6): - Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). - resampler_n_heads (`int`, *optional*, defaults to 16): - Number of heads in each Transformer block (for multi-headed self-attention). - resampler_head_dim (`int`, *optional*, defaults to 96): - Dimensionality of each head projection in the Transformer block. - qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`): - Whether or not to use qk layer norms in perceiver - """ - - model_type = "idefics" - - def __init__( - self, - use_resampler=False, - resampler_n_latents=64, - resampler_depth=6, - resampler_n_heads=16, - resampler_head_dim=96, - qk_layer_norms_perceiver=False, - **kwargs, - ): - self.use_resampler = use_resampler - self.resampler_n_latents = resampler_n_latents - self.resampler_depth = resampler_depth - self.resampler_n_heads = resampler_n_heads - self.resampler_head_dim = resampler_head_dim - self.qk_layer_norms_perceiver = qk_layer_norms_perceiver - - super().__init__(**kwargs) - - -class IdeficsConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an - Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Idefics-9B. - e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - additional_vocab_size (`int`, *optional`, defaults to 0): - Additional vocabulary size of the model, typically for the special "" token. Additional vocab tokens - are always trainable whereas regular vocab tokens can be frozen or not. - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`~IdeficsModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - alpha_initializer (`str`, *optional*, defaults to `"zeros"`): - Initialization type for the alphas. - alphas_initializer_range (`float`, *optional*, defaults to 0.0): - The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross - Attention. - alpha_type (`str`, *optional*, defaults to `"float"`): - Whether the gating alphas should be vectors or single floats. - rms_norm_eps (`float`, *optional*, defaults to 1e-6): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0) - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1) - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2) - End of stream token id. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - cross_layer_interval (`int`, *optional*, default to 1) - Interval for cross attention (from text to image) layers. - qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k - freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers - freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`): - Exceptions to freezing text layers when `freeze_text_layers` is `True` - freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head - freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers - freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`): - Exceptions to freezing vision layers when `freeze_vision_layers` is `True` - use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler - vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict - perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict - Example: - ```python - >>> from transformers import IdeficsModel, IdeficsConfig - >>> # Initializing a Idefics idefics-9b style configuration - >>> configuration = IdeficsConfig() - >>> # Initializing a model from the idefics-9b style configuration - >>> model = IdeficsModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "idefics" - is_composition = True - - def __init__( - self, - vocab_size=32000, - additional_vocab_size=0, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - dropout=0.0, - hidden_act="silu", - initializer_range=0.02, - alpha_initializer="zeros", - alphas_initializer_range=0.0, - alpha_type="float", - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - cross_layer_interval=1, - qk_layer_norms=False, - freeze_text_layers=True, - freeze_text_module_exceptions=[], - freeze_lm_head=False, - freeze_vision_layers=True, - freeze_vision_module_exceptions=[], - use_resampler=False, - vision_config=None, - perceiver_config=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.additional_vocab_size = additional_vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.dropout = dropout - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.alpha_initializer = alpha_initializer - self.alphas_initializer_range = alphas_initializer_range - self.alpha_type = alpha_type - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - - self.cross_layer_interval = cross_layer_interval - self.qk_layer_norms = qk_layer_norms - self.freeze_vision_layers = freeze_vision_layers - - self.freeze_text_layers = freeze_text_layers - self.freeze_text_module_exceptions = freeze_text_module_exceptions - self.freeze_vision_module_exceptions = freeze_vision_module_exceptions - self.freeze_lm_head = freeze_lm_head - - self.use_resampler = use_resampler - - if perceiver_config is None: - self.perceiver_config = IdeficsPerceiverConfig() - elif isinstance(perceiver_config, dict): - self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config) - elif isinstance(perceiver_config, IdeficsPerceiverConfig): - self.perceiver_config = perceiver_config - - if vision_config is None: - self.vision_config = IdeficsVisionConfig() - elif isinstance(vision_config, dict): - self.vision_config = IdeficsVisionConfig(**vision_config) - elif isinstance(vision_config, IdeficsVisionConfig): - self.vision_config = vision_config - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since - # PretrainedConfig.from_dict first instantiates the class with the config dict and only then - # updates the config object with `kwargs` from from_pretrained, so during the instantiation - # of this object many attributes have default values and haven't yet been overridden. - # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run. - - def to_dict(self): - """ - Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. - Returns: - `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, - """ - output = copy.deepcopy(self.__dict__) - - output["vision_config"] = self.vision_config.to_dict() - output["perceiver_config"] = self.perceiver_config.to_dict() - output["model_type"] = self.__class__.model_type - - return output diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py deleted file mode 100644 index afb8e1f9..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ /dev/null @@ -1,297 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Image processor class for Idefics.""" - -from typing import Callable, Dict, List, Optional, Union, Iterable -import numpy as np - -from PIL import Image - -import transformers -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature -from transformers.image_transforms import ( - resize, - to_channel_dimension_format, - rescale, - normalize, -) -from transformers.image_utils import ( - ChannelDimension, - ImageInput, - PILImageResampling, - make_list_of_images, - to_numpy_array, - valid_images, -) -from io import BytesIO -import base64 -import requests -from transformers import TensorType, is_torch_available - - -IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073] -IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711] - - -def convert_to_rgb(image): - # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background - # for transparent images. The call to `alpha_composite` handles this case - if image.mode == "RGB": - return image - - image_rgba = image.convert("RGBA") - background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) - alpha_composite = Image.alpha_composite(background, image_rgba) - alpha_composite = alpha_composite.convert("RGB") - return alpha_composite - - -class IdeficsImageProcessor(BaseImageProcessor): - r""" - Constructs a Idefics image processor. - Args: - image_size (`int`, *optional*, defaults to `224`): - Resize to image size - image_num_channels (`int`, *optional*, defaults to `3`): - Number of image channels. - image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): - Mean to use if normalizing the image. This is a float or list of floats the length of the number of - channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be - overridden by the `image_mean` parameter in the `preprocess` method. - image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): - Standard deviation to use if normalizing the image. This is a float or list of floats the length of the - number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. - Can be overridden by the `image_std` parameter in the `preprocess` method. - """ - - model_input_names = ["pixel_values"] - - def __init__( - self, - image_size: int = 224, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - image_num_channels: Optional[int] = 3, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.image_size = image_size - self.image_num_channels = image_num_channels - self.image_mean = image_mean - self.image_std = image_std - - def preprocess( - self, - images: ImageInput, - image_num_channels: Optional[int] = 3, - image_size: Optional[Dict[str, int]] = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - transform: Callable = None, - **kwargs, - ) -> TensorType.PYTORCH: - """ - Preprocess a batch of images. - Args: - images (`ImageInput`): - A list of images to preprocess. - image_size (`int`, *optional*, defaults to `self.image_size`): - Resize to image size - image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`): - Number of image channels. - image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): - Mean to use if normalizing the image. This is a float or list of floats the length of the number of - channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can - be overridden by the `image_mean` parameter in the `preprocess` method. - image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): - Standard deviation to use if normalizing the image. This is a float or list of floats the length of the - number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` - method. Can be overridden by the `image_std` parameter in the `preprocess` method. - transform (`Callable`, *optional*, defaults to `None`): - A custom transform function that accepts a single image can be passed for training. For example, - `torchvision.Compose` can be used to compose multiple transforms. If `None` - an inference mode is - assumed - and then a preset of inference-specific transforms will be applied to the images - Returns: - a PyTorch tensor of the processed images - """ - image_size = image_size if image_size is not None else self.image_size - image_num_channels = ( - image_num_channels - if image_num_channels is not None - else self.image_num_channels - ) - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - size = (image_size, image_size) - - if len(images) == 0: - return [] - - images = make_list_of_images(images) - - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - - # For training a user needs to pass their own set of transforms as a Callable. - # For reference this is what was used in the original IDEFICS training: - # transform = transforms.Compose([ - # convert_to_rgb, - # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), - # transforms.ToTensor(), - # transforms.Normalize(mean=image_mean, std=image_std), - # ]) - if transform is not None: - if not is_torch_available(): - raise ImportError("To pass in `transform` torch must be installed") - import torch - - images = [transform(x) for x in images] - return torch.stack(images) - - # for inference we do the exact transforms that were used to train IDEFICS - images = [convert_to_rgb(x) for x in images] - # further transforms expect numpy arrays - images = [to_numpy_array(x) for x in images] - images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images] - images = [self.rescale(image=image, scale=1 / 255) for image in images] - images = [self.normalize(x, mean=image_mean, std=image_std) for x in images] - images = [ - to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images - ] - # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available - images = BatchFeature( - data={"pixel_values": images}, tensor_type=TensorType.PYTORCH - )["pixel_values"] - - return images - - def fetch_images(self, image_url_or_urls: Union[str, List[str]]): - """ - Convert a single or a list of urls into the corresponding `PIL.Image` objects. - If a single url is passed, the return value will be a single object. If a list is passed a list of objects is - returned. - """ - headers = { - "User-Agent": ( - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0" - " Safari/537.36" - ) - } - if isinstance(image_url_or_urls, list): - return [self.fetch_images(x) for x in image_url_or_urls] - elif isinstance(image_url_or_urls, str): - image = image_url_or_urls - - if image.startswith("http://") or image.startswith("https://"): - response = requests.get( - image_url_or_urls, stream=True, headers=headers, timeout=(1, 5) - ) - response.raise_for_status() - content = response.content - elif image.startswith("data:"): - # https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image - #  - image = image.split(",")[-1] - content = base64.b64decode(image) - else: - raise ValueError(f"Unrecognized image {image}") - - try: - image = Image.open(BytesIO(content)) - # image.verify() - except Exception: - raise ValueError(f"Could not load image from url {image_url_or_urls}") - return image - else: - raise ValueError( - f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}" - ) - - def rescale( - self, - image: np.ndarray, - scale: float, - data_format: Optional[Union[str, ChannelDimension]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - **kwargs, - ) -> np.ndarray: - """ - Rescale an image by a scale factor. image = image * scale. - - Args: - image (`np.ndarray`): - Image to rescale. - scale (`float`): - The scaling factor to rescale pixel values by. - data_format (`str` or `ChannelDimension`, *optional*): - The channel dimension format for the output image. If unset, the channel dimension format of the input - image is used. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - Returns: - `np.ndarray`: The rescaled image. - """ - # return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs) - # requires 4.32 - return rescale(image, scale=scale, data_format=data_format, **kwargs) - - def normalize( - self, - image: np.ndarray, - mean: Union[float, Iterable[float]], - std: Union[float, Iterable[float]], - data_format: Optional[Union[str, ChannelDimension]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - **kwargs, - ) -> np.ndarray: - """ - Normalize an image. image = (image - image_mean) / image_std. - - Args: - image (`np.ndarray`): - Image to normalize. - mean (`float` or `Iterable[float]`): - Image mean to use for normalization. - std (`float` or `Iterable[float]`): - Image standard deviation to use for normalization. - data_format (`str` or `ChannelDimension`, *optional*): - The channel dimension format for the output image. If unset, the channel dimension format of the input - image is used. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - Returns: - `np.ndarray`: The normalized image. - """ - # TODO 4.32 - return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) - - -transformers.IdeficsImageProcessor = IdeficsImageProcessor diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py deleted file mode 100644 index 910e9bcd..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ /dev/null @@ -1,1474 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Idefics model.""" -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers import PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - dataclass, -) -from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig -from text_generation_server.models.custom_modeling.idefics_vision import ( - IdeficsVisionTransformer, -) -from text_generation_server.models.custom_modeling.idefics_perceiver import ( - IdeficsPerceiverResampler, -) -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, - FastLinear, -) -from text_generation_server.layers.rotary import PositionRotaryEmbedding -from loguru import logger - -dropout_layer_norm = None - - -@dataclass -class BaseModelOutputWithPastImage(BaseModelOutputWithPast): - image_hidden_states: Optional[torch.FloatTensor] = None - - -@dataclass -class CausalLMOutputWithPastImage(CausalLMOutputWithPast): - image_hidden_states: Optional[torch.FloatTensor] = None - - -# logger = logging.get_logger(__name__) - -# _CONFIG_FOR_DOC = "IdeficsConfig" - -# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [ -# "HuggingFaceM4/idefics-9b", -# "HuggingFaceM4/idefics-80b", -# # See all Idefics models at https://huggingface.co/models?filter=idefics -# ] - - -def expand_inputs_for_generation( - input_ids, - expand_size=1, - is_encoder_decoder=False, - attention_mask=None, - encoder_outputs=None, - **model_kwargs, -): - expanded_return_idx = ( - torch.arange(input_ids.shape[0]) - .view(-1, 1) - .repeat(1, expand_size) - .view(-1) - .to(input_ids.device) - ) - input_ids = input_ids.index_select(0, expanded_return_idx) - - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select( - 0, expanded_return_idx - ) - - if attention_mask is not None: - model_kwargs["attention_mask"] = attention_mask.index_select( - 0, expanded_return_idx - ) - model_kwargs["image_attention_mask"] = model_kwargs[ - "image_attention_mask" - ].index_select(0, expanded_return_idx) - model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select( - 0, expanded_return_idx - ) - - if is_encoder_decoder: - if encoder_outputs is None: - raise ValueError( - "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." - ) - encoder_outputs["last_hidden_state"] = ( - encoder_outputs.last_hidden_state.index_select( - 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) - ) - ) - model_kwargs["encoder_outputs"] = encoder_outputs - return input_ids, model_kwargs - - -def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): - # must have this key set to at least None - model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None) - - # update past - if "past_key_values" in outputs: - model_kwargs["past"] = outputs.past_key_values - elif "mems" in outputs: - model_kwargs["past"] = outputs.mems - elif "past_buckets_states" in outputs: - model_kwargs["past"] = outputs.past_buckets_states - else: - model_kwargs["past"] = None - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) - - # update attention masks - if not is_encoder_decoder: - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], - dim=-1, - ) - if "image_attention_mask" in model_kwargs: - image_attention_mask = model_kwargs["image_attention_mask"] - last_mask = image_attention_mask[:, -1, :].unsqueeze(1) - model_kwargs["image_attention_mask"] = last_mask - - return model_kwargs - - -def prepare_inputs_for_generation(input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - - pixel_values = kwargs.get("pixel_values", None) - image_attention_mask = kwargs.get("image_attention_mask", None) - # if pixel_values is None or image_attention_mask is None: - # raise ValueError("pixel values and image attention mask cannot be None") - - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - "pixel_values": pixel_values, - "image_attention_mask": image_attention_mask, - } - - -def freeze_model(model, module_exceptions=[]): - mapping = { - "LayerNorm": nn.LayerNorm, - "Linear": nn.Linear, - "Embedding": nn.Embedding, - } - module_exceptions_mapped = [mapping[m] for m in module_exceptions] - for module in model.modules(): - if module_exceptions and any( - [isinstance(module, t) for t in module_exceptions_mapped] - ): - module.requires_grad_( - True - ) # Explicitely setting it to true to avoid any mistakes - else: - module.requires_grad_(False) - return model - - -class IdeficsDecoupledPartialTPEmbedding(nn.Module): - def __init__( - self, - config, - weights, - ): - super().__init__() - self.num_embeddings = config.vocab_size - self.weight = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) - self.additional_weight = nn.Parameter( - weights.get_tensor("model.embed_tokens.additional_embedding.weight") - ) - - def forward(self, input_ids): - # Clone so that we don't modify the original input_ids later on - input_ids = input_ids.clone() - additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) - input_ids_additional_vocab = input_ids[additional_vocab_indices] - additional_embeddings = torch.nn.functional.embedding( - input_ids_additional_vocab - self.num_embeddings, self.additional_weight - ) - - # for successful lookup replace input_ids with 0, the results of these will be discarded anyway - input_ids[additional_vocab_indices] = 0 - full_vector = self.weight(input_ids) - - # overwrite the records with high indices - full_vector[additional_vocab_indices] = additional_embeddings - - return full_vector - - -class IdeficsDecoupledTensorParallelLinear(nn.Module): - # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the - regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, - then it will create `out_additional_features * in_features` additional parameters that are always trained. If - `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`. - """ - - def __init__( - self, - config, - weights, - ) -> None: - super().__init__() - self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights) - self.additional_fc = FastLinear.load( - config=config, - prefix="lm_head.additional_fc", - weights=weights, - bias=False, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output, speculative_logits = self.fc(input) - additional_features = self.additional_fc(input) - output = torch.cat((output, additional_features), -1) - - return output, speculative_logits - - def extra_repr(self) -> str: - """Overwriting `nn.Linear.extra_repr` to include new parameters.""" - return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format( - self.in_features, - self.out_features, - self.out_additional_features, - self.bias is not None, - self.partially_freeze, - ) - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class IdeficsRMSNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = eps - - def forward(self, hidden_states, residual=None): - from vllm_hpu_extension.kernels import rms_norm - - orig_shape = hidden_states.shape - if residual is not None: - residual += hidden_states.view(residual.shape) - else: - residual = hidden_states - # Note: HPUFusedRMSNorm requires 3D tensors as inputs - if len(orig_shape) == 2: - residual = residual.unsqueeze(0) - x = rms_norm().apply(residual, self.weight, self.variance_epsilon) - return x.view(orig_shape), residual.view(orig_shape) - - -# this was adapted from LlamaMLP -class IdeficsMLP(nn.Module): - def __init__( - self, - config, - prefix, - weights, - ): - super().__init__() - self.gate_up_proj = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], - weights=weights, - dim=0, - bias=False, - ) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) - shape = gate_up_states.shape - gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2) - return self.down_proj( - self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] - ) - - -# this was adapted from LlamaAttention -class IdeficsAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config, - prefix, - weights, - qk_layer_norms: bool = False, - is_cross_attention: bool = False, - ): - super().__init__() - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.dropout = config.dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.is_cross_attention = is_cross_attention - - # if not hasattr(nn.functional, "scaled_dot_product_attention"): - # raise ValueError("this model requires pytorch 2.0 or higher") - - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads //= weights.process_group.size() - - if self.is_cross_attention: - # kv_input_dim = ( - # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim - # ) - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=False - ) - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=False - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=False - ) - else: - self.qkv = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - self.o_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.o_proj", weights=weights, bias=False - ) - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, dim=self.head_dim, base=10000.0, device=weights.device - ) - self.qk_layer_norms = qk_layer_norms - if self.qk_layer_norms: - self.q_layer_norm = IdeficsRMSNorm( - prefix=f"{prefix}.q_layer_norm", - weights=weights, - eps=config.rms_norm_eps, - ) - self.k_layer_norm = IdeficsRMSNorm( - prefix=f"{prefix}.q_layer_norm", - weights=weights, - eps=config.rms_norm_eps, - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - is_cross_attention = self.is_cross_attention or key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - if is_cross_attention: - query_states = self.q_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim - ) # .transpose(1, 2) - query_states = query_states.transpose(1, 2) - ( - _, - kv_len, - _, - ) = ( - key_value_states.size() - ) # Note that, in this case, `kv_len` == `kv_seq_len` - key_states = ( - self.k_proj(key_value_states) - .view(bsz, kv_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(key_value_states) - .view(bsz, kv_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - else: - qkv = self.qkv(hidden_states) - query_states, key_states, value_states = qkv.split( - self.num_heads * self.head_dim, dim=2 - ) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ) # .transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim - ) # . transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim - ) # .transpose(1, 2) - kv_seq_len = q_len - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - max_s = max(kv_seq_len, q_len) - cos, sin = self.rotary_emb.get_cos_sin( - position_ids.view(-1), max_s, hidden_states.dtype - ) - - query_shape = query_states.shape - key_shape = key_states.shape - self.rotary_emb( - query_states.view(-1, *query_shape[2:]), - key_states.reshape(-1, *key_shape[2:]), - cos, - sin, - ) - - query_states = query_states.view(query_shape) - key_states = key_states.view(key_shape) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - if self.qk_layer_norms: - query_states = self.q_layer_norm(query_states) - key_states = self.k_layer_norm(key_states) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_output = nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout, - ) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - attn_weights = None - if output_attentions: - logger.warning_once( - "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" - ) - - return attn_output, attn_weights, past_key_value - - -# this was adapted from LlamaDecoderLayer -class IdeficsDecoderLayer(nn.Module): - def __init__(self, layer_id: int, config: IdeficsConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - prefix = f"model.layers.{layer_id}" - self.self_attn = IdeficsAttention( - config=config, - prefix=f"{prefix}.self_attn", - weights=weights, - qk_layer_norms=False, - is_cross_attention=False, - ) - self.mlp = IdeficsMLP( - config=config, - prefix=f"{prefix}.mlp", - weights=weights, - ) - self.input_layernorm = IdeficsRMSNorm( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = IdeficsRMSNorm( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) - self.dropout = config.dropout - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class IdeficsGatedCrossAttentionLayer(nn.Module): - def __init__(self, layer_id, config: IdeficsConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - prefix = f"model.gated_cross_attn_layers.{layer_id}" - self.cross_attn = IdeficsAttention( - config=config, - prefix=f"{prefix}.cross_attn", - weights=weights, - qk_layer_norms=True, - is_cross_attention=True, - ) - self.mlp = IdeficsMLP( - config=config, - prefix=f"{prefix}.mlp", - weights=weights, - ) - self.input_layernorm = IdeficsRMSNorm( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = IdeficsRMSNorm( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) - self.config = config.dropout - - self.act_cross_attn = nn.Tanh() - self.act_dense = nn.Tanh() - - self.alpha_cross_attn = nn.Parameter( - weights.get_tensor(f"{prefix}.alpha_cross_attn") - ) - self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense")) - - if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): - raise ValueError("Alpha parameters not initialized correctly!") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_hidden_states: Optional[torch.Tensor] = None, - image_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - no_images: Optional[bool] = False, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored - """ - if image_hidden_states is None: - raise ValueError( - "`image_hidden_states` is required for Idefics cross attention module which are visual features to be" - " conditioned on." - ) - - if past_key_value is not None: - raise NotImplementedError( - "Past key value states are not implemented for Idefics cross attention module." - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.cross_attn( - hidden_states=hidden_states, - key_value_states=image_hidden_states, - attention_mask=image_attention_mask, - output_attentions=output_attentions, - ) - # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) - # when there are no images the model is used in pure language mode - gate = 0 if no_images else 1 - hidden_states = ( - residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states - ) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) - hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`IdeficsConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -# @add_start_docstrings( -# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", -# LLAMA_START_DOCSTRING, -# ) -class IdeficsPreTrainedModel(PreTrainedModel): - config_class = IdeficsConfig - # base_model_prefix = "model" - # supports_gradient_checkpointing = True - # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] - - # def _init_weights(self, module): - # # important: this ported version of Idefics isn't meant for training from scratch - only - # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code - # # base should be used for training from scratch and it contains the correct code. - # std = self.config.initializer_range - # if isinstance(module, nn.Linear): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.bias is not None: - # module.bias.data.zero_() - # elif isinstance(module, nn.Embedding): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.padding_idx is not None: - # module.weight.data[module.padding_idx].zero_() - - # def _set_gradient_checkpointing(self, module, value=False): - # if isinstance(module, IdeficsModel): - # module.gradient_checkpointing = value - - -# LLAMA_INPUTS_DOCSTRING = r""" -# Args: -# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): -# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide -# it. - -# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and -# [`PreTrainedTokenizer.__call__`] for details. - -# [What are input IDs?](../glossary#input-ids) -# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): -# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - -# - 1 for tokens that are **not masked**, -# - 0 for tokens that are **masked**. - -# [What are attention masks?](../glossary#attention-mask) - -# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and -# [`PreTrainedTokenizer.__call__`] for details. - -# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see -# `past_key_values`). - -# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] -# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more -# information on the default strategy. - -# - 1 indicates the head is **not masked**, -# - 0 indicates the head is **masked**. -# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): -# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, -# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) -# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): -# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape -# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape -# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - -# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention -# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - -# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that -# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all -# `decoder_input_ids` of shape `(batch_size, sequence_length)`. -# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): -# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This -# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the -# model's internal embedding lookup matrix. -# use_cache (`bool`, *optional*): -# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see -# `past_key_values`). -# output_attentions (`bool`, *optional*): -# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned -# tensors for more detail. -# output_hidden_states (`bool`, *optional*): -# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for -# more detail. -# return_dict (`bool`, *optional*): -# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -# """ - - -# @add_start_docstrings( -# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", -# LLAMA_START_DOCSTRING, -# ) -class IdeficsModel(IdeficsPreTrainedModel): - # """ - # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`] - - # Args: - # config: IdeficsConfig - # """ - - def __init__(self, config: IdeficsConfig, weights): - super().__init__(config) - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = IdeficsDecoupledPartialTPEmbedding( - config=config, - weights=weights, - ) - - self.image_size = config.vision_config.image_size - self.vision_config = config.vision_config - self.vision_model = IdeficsVisionTransformer( - prefix="model.vision_model", - config=config.vision_config, - weights=weights, - ) - - # Perceiver Resampler - if config.use_resampler: - perceiver_config = config.perceiver_config - self.perceiver_resampler = IdeficsPerceiverResampler( - prefix="model.perceiver_resampler", - config=config, - embed_dim=config.vision_config.embed_dim, - depth=perceiver_config.resampler_depth, - n_heads=perceiver_config.resampler_n_heads, - head_dim=perceiver_config.resampler_head_dim, - n_latents=perceiver_config.resampler_n_latents, - weights=weights, - ) - - self.layers = nn.ModuleList( - [ - IdeficsDecoderLayer(layer_id, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - - self.cross_layer_interval = config.cross_layer_interval - num_cross_layers = config.num_hidden_layers // self.cross_layer_interval - self.gated_cross_attn_layers = nn.ModuleList( - [ - IdeficsGatedCrossAttentionLayer(layer_id, config, weights) - for layer_id in range(num_cross_layers) - ] - ) - # self.gradient_checkpointing = False - - self.norm = IdeficsRMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) - - # self.gradient_checkpointing = False - # Initialize weights and apply final processing - # self.post_init() - - # self.freeze_relevant_params(config) - - # def freeze_relevant_params(self, config=None): - # if config is None: - # config = self.config - - # if config.freeze_text_layers: - # self.freeze_text_layers(config.freeze_text_module_exceptions) - - # if config.freeze_vision_layers: - # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) - - # def freeze_text_layers(self, module_exceptions=[]): - # for module in [self.layers, self.norm]: - # freeze_model(module, module_exceptions=module_exceptions) - - # def freeze_vision_layers(self, module_exceptions=[]): - # freeze_model(self.vision_model, module_exceptions=module_exceptions) - - # def get_input_embeddings(self): - # return self.embed_tokens - - # def set_input_embeddings(self, value): - # self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - image_embeddings: Optional[torch.FloatTensor] = None, - image_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastImage]: - device = input_ids.device if input_ids is not None else inputs_embeds.device - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - elif position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - no_images = False - - if image_hidden_states is None: - if pixel_values is None and image_embeddings is None: - raise ValueError( - "Either pixel_values and image_embeddings have to be not-None." - ) - - elif pixel_values is not None and image_embeddings is not None: - raise ValueError( - "You cannot specify both pixel_values and image_embeddings at the same time" - ) - - elif pixel_values is not None: - no_images = len(torch.nonzero(pixel_values)) == 0 - pixel_values = pixel_values.to( - dtype=self.dtype, device=device - ) # fp16 compatibility - batch_size, num_images = pixel_values.shape[:2] - pixel_values = pixel_values.contiguous().view( - batch_size * num_images, *pixel_values.shape[2:] - ) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values - ).last_hidden_state - - elif image_embeddings is not None: - ( - batch_size, - num_images, - image_seq_len, - image_hidden_size, - ) = image_embeddings.size() - image_hidden_states = image_embeddings.to( - dtype=self.dtype, device=input_ids.device - ) - image_hidden_states = image_hidden_states.view( - batch_size * num_images, image_seq_len, image_hidden_size - ) - - if self.config.use_resampler: - image_hidden_states = self.perceiver_resampler(image_hidden_states) - image_seq_len, image_hidden_size = image_hidden_states.size( - 1 - ), image_hidden_states.size(2) - image_hidden_states = image_hidden_states.view( - batch_size, num_images * image_seq_len, image_hidden_size - ) - else: - no_images = False - num_images = pixel_values.shape[1] - image_seq_len = image_hidden_states.shape[1] // num_images - - # # Hack to use the model in full language modeling mode - # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device) - # Make image_attention_mask compatible with hidden states - text_seq_len = image_attention_mask.size(1) - image_attention_mask = image_attention_mask.unsqueeze(-1) - image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len) - image_attention_mask = image_attention_mask.view( - batch_size, text_seq_len, num_images * image_seq_len - ) - image_batch_size, image_sequence_length, _ = image_hidden_states.size() - image_hidden_shape = (image_batch_size, image_sequence_length) - if image_attention_mask is None: - image_attention_mask = torch.ones(image_hidden_shape, device=device) - image_attention_mask = self.invert_attention_mask(image_attention_mask) - - # if list(image_attention_mask.shape) != [4, 1, 1024, 64]: - # raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}") - - # if image_hidden_states is not None: - # else: - # image_attention_mask = None - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - - # if self.gradient_checkpointing and self.training: - # if use_cache: - # logger.warning_once( - # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - # ) - # use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - def vblock( - main_block, - hidden_states, - attention_mask, - position_ids, - past_key_value, - image_hidden_states, - image_attention_mask, - output_attentions, - use_cache, - no_images, - layer_idx, - cross_layer_interval, - gated_cross_attn_layers, - ): - # TODO(ls): Add cross attention values to respective lists - if layer_idx % cross_layer_interval == 0: - xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] - outputs = xblock( - hidden_states, - attention_mask=attention_mask, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - output_attentions=output_attentions, - use_cache=use_cache, - past_key_value=None, # not implemented - no_images=no_images, - ) - hidden_states = outputs[0] - - layer_outputs = main_block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - return layer_outputs - - # if self.gradient_checkpointing and self.training: - # past_key_value = None - # if use_cache: - # logger.warning_once( - # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - # ) - # use_cache = False - - # layer_outputs = torch.utils.checkpoint.checkpoint( - # vblock, - # decoder_layer, - # hidden_states, - # attention_mask, - # position_ids, - # past_key_value, - # image_hidden_states, - # image_attention_mask, - # output_attentions, - # use_cache, - # no_images, - # idx, - # self.cross_layer_interval, - # self.gated_cross_attn_layers, - # ) - # else: - layer_outputs = vblock( - decoder_layer, - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - output_attentions=output_attentions, - use_cache=use_cache, - no_images=no_images, - layer_idx=idx, - cross_layer_interval=self.cross_layer_interval, - gated_cross_attn_layers=self.gated_cross_attn_layers, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPastImage( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - image_hidden_states=image_hidden_states, - ) - - -class IdeficsForVisionText2Text(IdeficsPreTrainedModel): - def __init__( - self, - config, - weights, - ): - super().__init__(config) - self.model = IdeficsModel( - config=config, - weights=weights, - ) - - self.lm_head = IdeficsDecoupledTensorParallelLinear( - config=config, - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_embeddings: Optional[torch.FloatTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - image_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPastImage]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_embeddings=image_embeddings, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits, speculative_logits = self.lm_head(hidden_states) - - loss = None - - return ( - CausalLMOutputWithPastImage( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) - unwanted_kwargs = ["token_type_ids"] - for kwarg in unwanted_kwargs: - inputs.pop(kwarg, None) - return inputs - - @staticmethod - def _expand_inputs_for_generation( - *args, - **model_kwargs, - ): - return expand_inputs_for_generation(*args, **model_kwargs) - - @staticmethod - def _update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=False - ): - return update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py deleted file mode 100644 index 6da8045b..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py +++ /dev/null @@ -1,276 +0,0 @@ -# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. -# -# MIT License -# -# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - - -""" - -Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially -time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note -that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to -prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that -to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. - -References: - - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model - - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch - -""" -from typing import Optional, Tuple - -import torch -import torch.nn as nn - -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelRowLinear, -) - -EPS = 1e-5 - - -class IdeficsPerceiverResampler(nn.Module): - def __init__( - self, - prefix, - config, - embed_dim: int, - depth: int, - n_heads: int, - head_dim: int, - n_latents: int, - weights, - ) -> None: - """ - Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or - MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then - returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed - to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. - Could be e.g., VIT embed_dim, ResNet pool dim, and so on. - - Args: - config (`IdeficsConfig`): config object - embed_dim (`int`): The size of each embedding vector - depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). - n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). - head_dim (`int`): Dimensionality of each head projection in the Transformer block. - n_latents (`int`): - Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). - - """ - super().__init__() - self.embed_dim, self.n_heads, self.head_dim, self.n_latents = ( - embed_dim, - n_heads, - head_dim, - n_latents, - ) - self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver - - # Create Latents for Perceiver - self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents")) - - self.intermediate_dim = ( - self.embed_dim * 4 - if not hasattr(config.vision_config, "embed_dim") - else config.vision_config.embed_dim * 4 - ) - # Create Transformer Blocks - self.blocks = nn.ModuleList( - [ - nn.ModuleList( - [ - IdeficsPerceiverAttention( - prefix=f"{prefix}.blocks.{layer_id}.0", - config=config, - embed_dim=self.embed_dim, - n_heads=self.n_heads, - head_dim=self.head_dim, - qk_layer_norms=self.qk_layer_norms, - weights=weights, - ), - IdeficsMLP( - prefix=f"{prefix}.blocks.{layer_id}.1", - intermediate_size=self.intermediate_dim, - config=config, - weights=weights, - ), - ] - ) - for layer_id in range(depth) - ] - ) - self.layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS - ) - - def forward(self, context: torch.Tensor) -> torch.Tensor: - """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" - # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) - latents = self.latents.repeat(context.shape[0], 1, 1) - - # Feed through Perceiver Attention blocks... - for attn, ff in self.blocks: - latents = attn(context, latents) + latents - latents = ff(latents) + latents - - return self.layer_norm(latents) - - -class IdeficsPerceiverAttention(nn.Module): - def __init__( - self, - prefix, - config, - embed_dim: int, - n_heads: int, - head_dim: int, - qk_layer_norms: bool, - weights, - ) -> None: - """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" - super().__init__() - self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim - self.qk_layer_norms = qk_layer_norms - # Normalization & Scaling - self.context_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS - ) - self.latents_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS - ) - if self.qk_layer_norms: - self.q_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS - ) - self.k_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS - ) - - self.qk_scale = self.head_dim**-0.5 - - if n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads //= weights.process_group.size() - - # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). - self.q_proj = TensorParallelColumnLinear.load( - config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False - ) - self.k_proj = TensorParallelColumnLinear.load( - config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False - ) - self.v_proj = TensorParallelColumnLinear.load( - config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False - ) - - self.output_proj = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False - ) - - def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: - """ - Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! - - Args: - context (`torch.Tensor`): - Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. - latents (`torch.Tensor`): - Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. - - Returns: - `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross - from context. - """ - context = self.context_layer_norm(context) - latents = self.latents_layer_norm(latents) - batch_size, seq_length, embed_dim = context.shape[:3] - - # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! - # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` - q = self.q_proj(latents) - k = self.k_proj(torch.cat([context, latents], dim=-2)) - v = self.v_proj(torch.cat([context, latents], dim=-2)) - - # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) - # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] - # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) - q, k, v = [ - x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose( - 1, 2 - ) - for x in (q, k, v) - ] - - if self.qk_layer_norms: - q = self.q_layer_norm(q) - k = self.k_layer_norm(k) - - scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) - stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach()) - attn = stabilized_scores.softmax(dim=-1) - - # Attend & project back to output... - resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v) - # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads) - return self.output_proj(resampled.transpose(1, 2).flatten(-2)) - - -class IdeficsMLP(nn.Module): - def __init__( - self, - prefix, - intermediate_size, - config, - weights, - ): - """Simple MLP block with intermediate_size and embedding size""" - super().__init__() - self.embed_dim = config.vision_config.embed_dim - self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS) - self.fc = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.fc", - weights=weights, - bias=False, - ) - self.act = nn.ReLU() - self.c_proj = TensorParallelRowLinear.load( - config=config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=False, - ) - - def forward( - self, hidden_states: Optional[Tuple[torch.FloatTensor]] - ) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - hidden_states = self.fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - - return hidden_states diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py deleted file mode 100644 index ca61e27d..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py +++ /dev/null @@ -1,443 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Processor class for IDEFICS. -""" - -from typing import Callable, List, Optional, Union -from urllib.parse import urlparse - -from transformers.feature_extraction_utils import BatchFeature -from transformers.processing_utils import ProcessorMixin -from transformers.tokenization_utils_base import ( - BatchEncoding, - PaddingStrategy, - TextInput, - TruncationStrategy, -) -from transformers.utils import TensorType, is_torch_available - - -if is_torch_available(): - import torch - - -IMAGE_TOKEN = "" - - -# copied from m4.training.packing -def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1): - # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]] - - # If any of images index are more than num_classes, set them to -1. - # Words after the max number of images allowed have been seen don't attend on anything - if num_classes != -1: - incremental_mask[incremental_mask >= num_classes] = -1 - - negatives = incremental_mask == -1 - incremental_mask[negatives] = 0 - attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) - attn_mask[negatives, :] = 0 - return attn_mask - - -# copied from m4.training.packing -def image_attention_mask_for_packed_input_ids(input_ids, tokenizer): - image_attention_mask = torch.full_like(input_ids, fill_value=-1) - next_image_attention_mask = torch.full_like(input_ids, fill_value=-1) - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - eod_token_id = tokenizer.eos_token_id - for batch_idx in range(input_ids.size(0)): - count = -1 - seen_eod = False - for idx, token_id in enumerate(input_ids[batch_idx]): - if token_id == image_token_id: - count += 1 - image_attention_mask[batch_idx][idx] = count - seen_eod = False - else: - image_attention_mask[batch_idx][idx] = count - - if seen_eod: - image_attention_mask[batch_idx][idx] = -1 - - if token_id == eod_token_id: - seen_eod = True - - for batch_idx in range(input_ids.size(0)): - count = -1 - seen_eod = False - for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1): - token_id = input_ids[batch_idx][idx] - if token_id == image_token_id: - count += 1 - next_image_attention_mask[batch_idx][idx] = count - seen_eod = False - else: - next_image_attention_mask[batch_idx][idx] = count - - if token_id == eod_token_id: - seen_eod = True - - if seen_eod: - next_image_attention_mask[batch_idx][idx] = -1 - - non_negative_indices = next_image_attention_mask[batch_idx] != -1 - next_image_attention_mask[batch_idx][non_negative_indices] -= count - next_image_attention_mask[batch_idx][non_negative_indices] *= -1 - - return image_attention_mask, next_image_attention_mask - - -def is_url(string): - """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately - invalidated the url""" - if " " in string: - return False - result = urlparse(string) - return all([result.scheme, result.netloc]) - - -def is_image(string): - """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately - invalidated the url""" - return is_url(string) or string.startswith("data:") - - -class IdeficsProcessor(ProcessorMixin): - r""" - Constructs a IDEFICS processor which wraps a LLama tokenizer and IDEFICS image processor into a single processor. - - [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See - the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information. - - Args: - image_processor (`IdeficsImageProcessor`): - An instance of [`IdeficsImageProcessor`]. The image processor is a required input. - tokenizer (`LlamaTokenizerFast`): - An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. - image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image) - """ - - attributes = ["image_processor", "tokenizer"] - image_processor_class = "IdeficsImageProcessor" - tokenizer_class = "LlamaTokenizerFast" - - def __init__( - self, - image_processor, - tokenizer=None, - image_size=224, - add_end_of_utterance_token=None, - **kwargs, - ): - if image_processor is None: - raise ValueError("You need to specify an `image_processor`.") - if tokenizer is None: - raise ValueError("You need to specify a `tokenizer`.") - - super().__init__(image_processor, tokenizer) - self.current_processor = self.image_processor - self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - - self.default_image_dims = ( - self.image_processor.image_num_channels, - self.image_processor.image_size, - self.image_processor.image_size, - ) - - self.tokenizer_was_trained_with_end_of_utterance_token = ( - True - if "" - in self.tokenizer.special_tokens_map.get("additional_special_tokens", []) - else False - ) - - def __call__( - self, - prompts: Union[List[TextInput], List[List[TextInput]]], - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - transform: Callable = None, - add_eos_token=False, - add_end_of_utterance_token=None, - debug=False, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - ) -> BatchEncoding: - """This method takes batched or non-batched prompts made of text and images and converts them into prompts that - the model was trained on and prepares the image pixel values for the model to process. - - Args: - prompts (`Union[List[TextInput], [List[List[TextInput]]]]`): - either a single prompt or a batched list of prompts - see the detailed description immediately after - the end of the arguments doc section. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - transform (`Callable`, *optional*): - A custom transform function that accepts a single image can be passed for training. For example, - `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific - set of transforms will be applied to the images - add_eos_token (`bool`, *optional*, defaults to `False`): - Adds `eos_token` at the end of the final prompt if True` - add_end_of_utterance_token (`bool`, *optional*) - Whether to automatically add `` after each prompt's text input (unless followed by an - image). If `None` the tokenizer will be checked instead and if this token is found in - `additional_special_tokens` then the value will be `True`. - debug (`bool`, *optional*, defaults to `False`): - `True` value will help debug prompt generation by dumping useful information - return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`): - The type of tensors to return. Can be one of: - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - Returns: - a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be - directly passed to `model.generate` - - Detailed explanation: - - Each entry in `prompts` is either a text to be passed as is or an image that will be processed. - - An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved. - - When the processor encounters an image it'll inject `` - entry into the prompt. - - Example: - - ```python - checkpoint = "HuggingFaceM4/idefics-9b" - processor = AutoProcessor.from_pretrained(checkpoint) - url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg" - img = processor.image_processor.fetch_images([url])[0] - - prompts = [ - "User:", - img, - "Describe this image.\nAssistant: An image of two kittens in grass.\n", - "User:", - "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg", - "Describe this image.\nAssistant:", - ] - - inputs = processor(prompts, return_tensors="pt") - generated_ids = model.generate(**inputs, max_length=100) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - ``` - - In this example the `prompts` will be converted into: - - ``` - User:Describe this image. - Assistant: An image of two kittens in grass. - User:Describe this image. - Assistant:' - ``` - - and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the - `pixel_values` dict entry of the return value. - - This example also examplifies that images can be passed as objects or as text urls. It can be seen that the - first image is passed as object and the second one as a url. - - To do training do: - - ```python - image_transform = transforms.Compose( - [ - transforms.RandomResizedCrop( - (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC - ), - transforms.ToTensor(), - transforms.Normalize(mean=self.image_mean, std=self.image_std), - ] - ) - inputs = processor(prompts, transform=image_transform, return_tensors="pt") - ``` - - In order to help debug prompt generation enable `debug=True` which will show you what's happening. - - """ - - # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it - if add_end_of_utterance_token is None: - add_end_of_utterance_token = ( - self.tokenizer_was_trained_with_end_of_utterance_token - ) - - # turn non-batched prompts into batched - if not any(isinstance(i, list) for i in prompts): - prompts = [prompts] - - fake_token = "" - image_token = "" - end_of_utterance_token = "" - - def image_tokens(last_was_image): - if last_was_image: - return image_token + fake_token - else: - return fake_token + image_token + fake_token - - all_texts = [] - all_images = [] - for sample in prompts: - # the model was trained on samples starting with - full_text = f"{self.tokenizer.bos_token}" - - # an image can either be an image object in the item or the url, everything else is a verbatim prompt text - image_objects = [] - last_was_image = False - last_was_text = False - for i, item in enumerate(sample): - if i > 0: - last_was_text = True if not last_was_image else False - - if isinstance(item, str): - item = item.strip(" ") - if is_image(item): - image = self.image_processor.fetch_images(item) - full_text += image_tokens(last_was_image) - image_objects.append(image) - last_was_image = True - else: - # we add end_of_utterance_token between each subsequent text prompts (but not at the last one!) - if add_end_of_utterance_token and last_was_text: - full_text += end_of_utterance_token - full_text += item - last_was_image = False - else: - # must be an image obj - full_text += image_tokens(last_was_image) - image_objects.append(item) - last_was_image = True - - if add_eos_token: - full_text += self.tokenizer.eos_token - - if debug is True: - print(f"{full_text=}") - - image_objects = self.image_processor(image_objects, transform=transform) - - text_encoding = self.tokenizer( - text=full_text, - add_special_tokens=False, - padding=padding, - truncation=truncation, - max_length=max_length, - ) - - all_texts.append(text_encoding["input_ids"]) - all_images.append(image_objects) - - max_seq_len = max(len(x) for x in all_texts) - - # max_num_images has to be at least 1 even when there are no images - max_num_images = max(len(x) for x in all_images) - max_num_images = max(1, max_num_images) - - at_least_one_image = sum(len(x) for x in all_images) > 0 - output_input_ids = [] - output_images = [] - output_attention_masks = [] - for text, images in zip(all_texts, all_images): - padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len - unpadded_seq_len = len(text) - start = max_seq_len - unpadded_seq_len - padded_input_ids[start:] = text[:max_seq_len] - - attention_mask = torch.zeros((max_seq_len,), dtype=torch.long) - attention_mask[start:] = 1 - - image_count = padded_input_ids.count(self.image_token_id) - local_max_num_images = min(image_count, max_num_images) - - current_images = images[:local_max_num_images] - - if len(current_images) > 0: - padded_image_tensor = torch.zeros( - max_num_images, *current_images.size()[1:] - ) - padded_image_tensor[: current_images.size(0)] = current_images - else: - padded_image_tensor = torch.zeros( - max_num_images, *self.default_image_dims - ) - - output_images.append(padded_image_tensor) - output_input_ids.append(torch.tensor(padded_input_ids)) - - output_attention_masks.append(attention_mask) - - output_input_ids = torch.stack(output_input_ids) - output_images = torch.stack(output_images) - output_attention_masks = torch.stack(output_attention_masks) - - if at_least_one_image: - image_attention_mask, _ = image_attention_mask_for_packed_input_ids( - output_input_ids, self.tokenizer - ) - image_attention_mask = incremental_to_binary_attention_mask( - image_attention_mask, num_classes=max_num_images - ) - else: - # in full language mode we set the image mask to all-0s - image_attention_mask = torch.zeros( - output_input_ids.shape[0], - output_input_ids.shape[1], - 1, - dtype=torch.bool, - ) - - return BatchFeature( - data={ - "input_ids": output_input_ids, - "attention_mask": output_attention_masks, - "pixel_values": output_images, - "image_attention_mask": image_attention_mask, - } - ) - - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to - the docstring of this method for more information. - """ - return self.tokenizer.decode(*args, **kwargs) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names - image_processor_input_names = self.image_processor.model_input_names - return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py deleted file mode 100644 index 7d2051e0..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ /dev/null @@ -1,529 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" - - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.utils import ( - ModelOutput, - logging, -) -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelRowLinear, - TensorParallelEmbedding, -) - -logger = logging.get_logger(__name__) - - -@dataclass -class IdeficsVisionModelOutput(ModelOutput): - """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - - Args: - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics -class IdeficsVisionEmbeddings(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.class_embedding = nn.Parameter( - weights.get_tensor(f"{prefix}.class_embedding") - ) - - self.patch_embedding = nn.Conv2d.load_no_bias( - prefix=f"{prefix}.patch_embedding", - weights=weights, - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 - self.position_embedding = TensorParallelEmbedding( - prefix="model.vision_model.embeddings.position_embedding", weights=weights - ) - self.position_ids = ( - torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device) - ) - - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype) - ) # shape = [*, width, grid, grid] - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision -class IdeficsVisionAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, prefix, config, weights): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_heads // weights.process_group.size() - self.embed_dim = self.embed_dim // weights.process_group.size() - - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=True - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=True - ) - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=True - ) - self.out_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.out_proj", weights=weights, bias=True - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + causal_attention_mask - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit akward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision -class IdeficsVisionMLP(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.fc1", weights=weights, bias=True - ) - self.fc2 = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.fc2", weights=weights, bias=True - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision -class IdeficsVisionEncoderLayer(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = IdeficsVisionAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights - ) - self.layer_norm1 = nn.LayerNorm.load( - prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps - ) - self.mlp = IdeficsVisionMLP( - prefix=f"{prefix}.mlp", config=config, weights=weights - ) - self.layer_norm2 = nn.LayerNorm.load( - prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - causal_attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision -class IdeficsVisionEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`IdeficsVisionEncoderLayer`]. - - Args: - config: IdeficsVisionConfig - """ - - def __init__(self, prefix, config, weights): - super().__init__() - self.config = config - self.layers = nn.ModuleList( - [ - IdeficsVisionEncoderLayer( - prefix=f"{prefix}.encoder.layers.{layer_id}", - config=config, - weights=weights, - ) - for layer_id in range(config.num_hidden_layers) - ] - ) - # self.gradient_checkpointing = False - - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Causal mask for the text model. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # if self.gradient_checkpointing and self.training: - - # def create_custom_forward(module): - # def custom_forward(*inputs): - # return module(*inputs, output_attentions) - - # return custom_forward - - # layer_outputs = torch.utils.checkpoint.checkpoint( - # create_custom_forward(encoder_layer), - # hidden_states, - # attention_mask, - # causal_attention_mask, - # ) - # else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, encoder_states, all_attentions] - if v is not None - ) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) - - -# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer -class IdeficsVisionTransformer(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.config = config - - self.embeddings = IdeficsVisionEmbeddings( - prefix=f"{prefix}.embeddings", config=config, weights=weights - ) - self.pre_layrnorm = nn.LayerNorm.load( - prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps - ) - self.encoder = IdeficsVisionEncoder( - prefix=prefix, config=config, weights=weights - ) - self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - - # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward - def forward( - self, - pixel_values: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - )