From 765ca78014f0c601a31bb80f8e17c761a479fbce Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 7 Jan 2025 22:05:47 +0000 Subject: [PATCH] fix: clean up idefics 3 and improve prefix handling --- .../custom_modeling/flash_llama_modeling.py | 27 +- .../models/custom_modeling/idefics3.py | 509 +----------------- .../models/custom_modeling/vlm.py | 2 +- .../models/vlm_causal_lm.py | 61 ++- 4 files changed, 76 insertions(+), 523 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 20bab01b..7525940a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -632,20 +632,24 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - base_model = "" if prefix.endswith("text_model") else ".model" - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}{base_model}.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: @@ -656,18 +660,13 @@ class FlashLlamaForCausalLM(torch.nn.Module): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier - if not prefix: - head_prefix = suffix - elif prefix.endswith("text_model"): - head_prefix = suffix - else: - head_prefix = f"{prefix}.{suffix}" + prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=head_prefix, - weights=weights, + prefix, + weights, ) # Used in Granite diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py index 2c467877..81e03943 100644 --- a/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Idefics2 model.""" +""" PyTorch Idefics3 model.""" from typing import List, Optional, Tuple @@ -50,7 +50,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class Idefics2VisionEmbeddings(nn.Module): +class Idefics3VisionEmbeddings(nn.Module): """ This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable resolution. 
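Note on the prefix handling above: FlashLlamaForCausalLM now takes a `name` argument (defaulting to "model") instead of special-casing prefixes that end in "text_model", and composes weight names from `prefix` and `name`. The sketch below mirrors the embed_tokens, FlashLlamaModel, and lm_head prefix expressions from that hunk; `resolve_llama_prefixes` is a hypothetical helper written only for illustration (the tie_word_embeddings suffix handling is left aside), and the Idefics3 example assumes the `prefix="model"`, `name="text_model"` values that `load_text_model` passes further down in this patch.

# Illustrative sketch only; this helper is not part of the patch.
def resolve_llama_prefixes(prefix: str, name: str = "model"):
    # Embedding and decoder weights live under "<prefix>.<name>", or just
    # "<name>" when no prefix is given.
    base = name if not prefix else f"{prefix}.{name}"
    embed_tokens = f"{base}.embed_tokens"
    # Mirrors the new lm_head expression: the head stays at the checkpoint root
    # unless the submodule is the plain "model" under a non-empty prefix.
    lm_head = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
    return embed_tokens, base, lm_head

# Standalone Llama checkpoint:
#   resolve_llama_prefixes("") == ("model.embed_tokens", "model", "lm_head")
# Text model inside a VLM wrapper (assuming prefix="model", name="text_model"):
#   resolve_llama_prefixes("model", "text_model")
#       == ("model.text_model.embed_tokens", "model.text_model", "lm_head")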
@@ -131,7 +131,7 @@ class Idefics2VisionEmbeddings(nn.Module): return embeddings -class Idefics2VisionAttention(nn.Module): +class Idefics3VisionAttention(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config @@ -229,7 +229,7 @@ class Idefics2VisionAttention(nn.Module): return attn_output -class Idefics2VisionMLP(nn.Module): +class Idefics3VisionMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config @@ -248,11 +248,11 @@ class Idefics2VisionMLP(nn.Module): return hidden_states -class Idefics2EncoderLayer(nn.Module): +class Idefics3EncoderLayer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention( + self.self_attn = Idefics3VisionAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) self.layer_norm1 = nn.LayerNorm.load( @@ -261,7 +261,7 @@ class Idefics2EncoderLayer(nn.Module): self.layer_norm2 = nn.LayerNorm.load( prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights ) - self.mlp = Idefics2VisionMLP( + self.mlp = Idefics3VisionMLP( prefix=f"{prefix}.mlp", config=config, weights=weights ) @@ -288,13 +288,13 @@ class Idefics2EncoderLayer(nn.Module): return hidden_states -class Idefics2Encoder(nn.Module): +class Idefics3Encoder(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config self.layers = nn.ModuleList( [ - Idefics2EncoderLayer( + Idefics3EncoderLayer( prefix=f"{prefix}.layers.{i}", config=config, weights=weights ) for i in range(config.num_hidden_layers) @@ -316,14 +316,14 @@ class Idefics2Encoder(nn.Module): return hidden_states -class Idefics2VisionTransformer(nn.Module): +class Idefics3VisionTransformer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config - self.embeddings = Idefics2VisionEmbeddings( + self.embeddings = Idefics3VisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights ) - self.encoder = Idefics2Encoder( + self.encoder = Idefics3Encoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) self.post_layernorm = nn.LayerNorm.load( @@ -377,317 +377,26 @@ class Idefics2VisionTransformer(nn.Module): return last_hidden_state -class Idefics2MLP(nn.Module): +class Idefics3SimpleMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - act = config.text_config.hidden_act - self.act = ( - ACT2FN[act] - if "gelu" not in act - else lambda x: torch.nn.functional.gelu( - x, - approximate=( - "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" - ), - ) - ) - self.gate_up_proj = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], - weights=weights, - dim=0, - bias=False, - ) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ) + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj - def forward(self, hidden_states): - start_shape = hidden_states.shape[:-1] - gate_up_states = self.gate_up_proj(hidden_states) - intermediate_size = gate_up_states.shape[-1] // 2 - gate_up_states = 
gate_up_states.view(-1, 2, intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] - ).view(*start_shape, -1) - - -class Idefics2RMSNorm(nn.Module): - def __init__(self, prefix, weights, eps): - """ - Idefics2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter( - weights.get_tensor(f"{prefix}.weight"), requires_grad=False - ) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class Idefics2PerceiverAttention(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.layer_idx = None - self.hidden_size = config.text_config.hidden_size - self.num_heads = config.perceiver_config.resampler_n_heads - self.head_size = config.perceiver_config.resampler_head_dim - self.num_key_value_heads = config.perceiver_config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.attention_dropout = config.perceiver_config.attention_dropout - self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - self.num_key_value_heads // weights.process_group.size() - ) - - self.q_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.q_proj", - weights=weights, - bias=False, - ) - self.kv = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - self.o_proj = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False - ) - - self.is_causal = False - - def forward( - self, - latents: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = latents.size() - kv_seq_len = q_len + context.size()[1] - - hidden_states = torch.concat([context, latents], dim=-2) - query_states = self.q_proj(latents) - kv = self.kv(hidden_states) - key_states, value_states = kv.split( - [ - self.head_size * self.num_key_value_heads, - self.head_size * self.num_key_value_heads, - ], - dim=2, - ) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_size - ).transpose(1, 2) - key_states = key_states.view( - bsz, kv_seq_len, self.num_key_value_heads, self.head_size - ).transpose(1, 2) - value_states = value_states.view( - bsz, kv_seq_len, self.num_key_value_heads, self.head_size - ).transpose(1, 2) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_size) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # 
upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size) - - attn_output = self.o_proj(attn_output) - - return attn_output - - -class Idefics2PerceiverLayer(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.hidden_size = config.text_config.hidden_size - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps - - self.input_latents_norm = Idefics2RMSNorm( - prefix=f"{prefix}.input_latents_norm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.input_context_norm = Idefics2RMSNorm( - prefix=f"{prefix}.input_context_norm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.self_attn = Idefics2PerceiverAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights - ) - self.post_attention_layernorm = Idefics2RMSNorm( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - - def forward( - self, - latents: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - """ - Args: - latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. 
- """ - residual = latents - - latents = self.input_latents_norm(latents) - context = self.input_context_norm(context) - - latents = self.self_attn( - latents=latents, - context=context, - attention_mask=attention_mask, - ) - latents = residual + latents - residual = latents - - latents = self.post_attention_layernorm(latents) - latents = self.mlp(latents) - latents = residual + latents - - return latents - - -class Idefics2PerceiverResampler(nn.Module): - def __init__(self, prefix, config, weights) -> None: - super().__init__() - self.hidden_size = config.text_config.hidden_size - self.hidden_act = config.perceiver_config.hidden_act - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps - - # Create Latents for Perceiver - self.latents = weights.get_tensor(f"{prefix}.latents") - - # Create Transformer Blocks - self.layers = nn.ModuleList( - [ - Idefics2PerceiverLayer( - prefix=f"{prefix}.layers.{idx}", config=config, weights=weights - ) - for idx in range(self.depth) - ] - ) - self.norm = Idefics2RMSNorm( - prefix=f"{prefix}.norm", - weights=weights, - eps=config.text_config.rms_norm_eps, - ) - - def forward( - self, - context: torch.Tensor, - attention_mask, - ) -> torch.Tensor: - # seq embed -> bsz seq embed - latents = self.latents.unsqueeze(0).expand( - (context.shape[0], *self.latents.size()) - ) - - latent_attention_mask = torch.ones( - (attention_mask.size(0), latents.size(1)), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) - attention_mask = _prepare_4d_attention_mask( - attention_mask, latents.dtype, tgt_len=self.n_latents - ) - - compressed_context = latents - for perceiver_layer in self.layers: - compressed_context = perceiver_layer( - compressed_context, - context, - attention_mask=attention_mask, - ) - compressed_context = self.norm(compressed_context) - - return compressed_context - - -class Idefics2Connector(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.modality_projection = Idefics2MLP( - prefix=f"{prefix}.modality_projection", config=config, weights=weights - ) - self.perceiver_resampler = Idefics2PerceiverResampler( - prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights - ) - - def forward(self, image_hidden_states, attention_mask): - image_hidden_states = self.modality_projection(image_hidden_states) - image_hidden_states = self.perceiver_resampler( - context=image_hidden_states, attention_mask=attention_mask - ) - return image_hidden_states + def forward(self, x): + return self.proj(x) class Idefics3Connector(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - self.modality_projection = TensorParallelRowLinear.load( - prefix=f"{prefix}.modality_projection.proj", - config=config, - weights=weights, - bias=False, - ) + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) self.scale_factor = config.scale_factor def pixel_shuffle(self, x, scale_factor=2): @@ -706,8 +415,7 @@ class Idefics3Connector(nn.Module): x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) return x - def forward(self, image_hidden_states, attention_mask): - print(image_hidden_states.device, self.modality_projection.linear.weight.device) + def forward(self, image_hidden_states): image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) 
image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states @@ -726,7 +434,7 @@ class Idefics3ForConditionalGeneration(nn.Module): vision_config = config.vision_config self.text_model = load_text_model( - prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", + prefix="model" if not prefix else f"{prefix}.model", config=config.text_config, weights=weights, name="text_model", @@ -735,7 +443,7 @@ class Idefics3ForConditionalGeneration(nn.Module): # The vision and connector models are not quantized. with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): - self.vision_model = Idefics2VisionTransformer( + self.vision_model = Idefics3VisionTransformer( prefix=( f"{prefix}.model.vision_model" if prefix else "model.vision_model" ), @@ -810,7 +518,6 @@ class Idefics3ForConditionalGeneration(nn.Module): dim=(-1, -2, -3) ) != nb_values_per_image pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask if pixel_attention_mask is None: pixel_attention_mask = torch.ones( @@ -850,7 +557,6 @@ class Idefics3ForConditionalGeneration(nn.Module): # Modality projection & resampling image_hidden_states = self.connector( image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), ) all_states.append(image_hidden_states) @@ -877,164 +583,3 @@ class Idefics3ForConditionalGeneration(nn.Module): hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits - - -class Idefics2ForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = None - config.vision_config.speculator = config.speculator - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - - vision_config = config.vision_config - self.text_model = load_text_model( - prefix="model" if not prefix else f"{prefix}.model", - config=config.text_config, - weights=weights, - name="text_model", - ) - self.dtype = weights.dtype - - # The vision and connector models are not quantized. - with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): - self.vision_model = Idefics2VisionTransformer( - prefix=( - f"{prefix}.model.vision_model" if prefix else "model.vision_model" - ), - config=vision_config, - weights=weights, - ) - - config.quantize = None - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) - - self.config = config - self.image_seq_len = config.perceiver_config.resampler_n_latents - self.image_token_id = config.image_token_id - self.pad_token_id = ( - config.pad_token_id if config.pad_token_id is not None else -1 - ) - - def _merge_input_ids_with_image_features( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_features: torch.Tensor, - ): - """In place merges in vision_embeddings with inputs_embeds.""" - # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id - # Let's pray we have enabled enough slots ! 
- inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, - adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - - hidden_states = self.text_model.model( - inputs_embeds=inputs_embeds, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, - adapter_data=adapter_data, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.text_model.lm_head(hidden_states) - return logits, speculative_logits diff --git 
a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 82e409a6..94b8522d 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 082f4b81..daf5d063 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLM, ) from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -29,25 +30,32 @@ IDEFICS3_GLOBAL_IMG_TOKEN = "" def get_image_prompt_string( - rows=0, - cols=0, - seq_len=1, - fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, - img_token=IDEFICS3_IMAGE_TOKEN, - global_token=IDEFICS3_GLOBAL_IMG_TOKEN, + *, + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, ): - tokens = img_token * seq_len - end_token = f"{fake_token}{global_token}{tokens}{fake_token}" + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" - if rows == 0 or cols == 0: - return end_token - - grid = "\n".join( - "".join(f"{fake_token}{tokens}" for j in range(cols)) - for i in range(rows) + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" ) - - return f"{grid}\n\n{end_token}" + return text_split_images def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): @@ -89,18 +97,17 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str / (config.scale_factor**2) ) image_str = get_image_prompt_string( - rows=n_rows, - cols=n_cols, - seq_len=image_seq_len, - fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, - img_token=IDEFICS3_IMAGE_TOKEN, - global_token=IDEFICS3_GLOBAL_IMG_TOKEN, + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger log_master( logger.info, @@ -238,9 +245,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if images: kwargs = {} - match processor.image_processor_class: - case "Idefics3ImageProcessor": - kwargs["return_row_col_info"] = True + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == 
"Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True image_inputs = processor.image_processor( images, return_tensors="pt", **kwargs