diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index ae926139..545cddd0 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -114,7 +114,7 @@ impl Client {
         let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

         let mut inputs = String::new();
-        inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
+        inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
         inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));

         requests.push(Request {
diff --git a/router/src/config.rs b/router/src/config.rs
index 9b5a2404..0de0a56c 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -92,6 +92,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Idefics2,
     Ssm,
     GptBigcode,
     Santacoder,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 2029c7e0..94e59b2d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -540,7 +540,30 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
-        Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
+        Some(Config::Idefics | Config::Idefics2) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = 1;
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         _ => inputs.clone(),
     };
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 06792b0d..99cc7be8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -68,6 +68,7 @@ try:
     )
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.llava_next import LlavaNext
+    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
@@ -579,6 +580,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == "idefics2":
+        if FLASH_ATTENTION:
+            return Idefics2(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index ffaa0c32..e78260fc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):


 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()
-
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+                f"{name}.embed_tokens"
+                if not prefix
+                else f"{prefix}.{name}.embed_tokens"
             ),
             weights=weights,
         )
         self.model = MistralModel(
-            prefix="model" if not prefix else f"{prefix}.model",
+            prefix=name if not prefix else f"{prefix}.{name}",
             config=config,
             weights=weights,
         )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            # TODO dirty hack for idefics2.
+            prefix=(
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+            ),
             weights=weights,
         )
         self.max_past = config.sliding_window
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
new file mode 100644
index 00000000..dc51fbcd
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -0,0 +1,803 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Idefics2 model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +import math + +from transformers.activations import ACT2FN +from transformers.image_processing_utils import select_best_resolution +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class 
+class Idefics2VisionAttention(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = self.embed_dim // self.num_heads
+        if self.head_size * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_size**-0.5
+        self.dropout = config.attention_dropout
+
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
+
+        self.qkv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+        )
+        self.is_causal = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, q_len, _ = hidden_states.size()
+
+        qkv = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                self.head_size * self.num_heads,
+                self.head_size * self.num_heads,
+            ],
+            dim=2,
+        )
+
+        query_states = query_states.view(
+            batch_size, q_len, self.num_heads, self.head_size
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            batch_size, q_len, self.num_heads, self.head_size
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            batch_size, q_len, self.num_heads, self.head_size
+        ).transpose(1, 2)
+
+        k_v_seq_len = key_states.shape[-2]
+        attn_weights = (
+            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+        )
+
+        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
+            raise ValueError(
+                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class Idefics2VisionMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class Idefics2EncoderLayer(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Idefics2VisionAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.layer_norm1 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
+        )
+        self.layer_norm2 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
+        )
+        self.mlp = Idefics2VisionMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
+
+    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Idefics2Encoder(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList(
+            [
+                Idefics2EncoderLayer(
+                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+
+    # Ignore copy
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+                attention_mask,
+            )
+        return hidden_states
+
+
+class Idefics2VisionTransformer(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.config = config
+        self.embeddings = Idefics2VisionEmbeddings(
+            prefix=f"{prefix}.embeddings", config=config, weights=weights
+        )
+        self.encoder = Idefics2Encoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
+        self.post_layernorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.post_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+
+    def forward(
+        self,
+        pixel_values,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
+    ):
+        batch_size = pixel_values.size(0)
+        if patch_attention_mask is None:
+            patch_size = self.config.patch_size
+            patch_attention_mask = torch.ones(
+                (
+                    batch_size,
+                    pixel_values.size(2) // patch_size,
+                    pixel_values.size(3) // patch_size,
+                )
+            )
+            patch_attention_mask = patch_attention_mask.to(
+                dtype=torch.bool, device=pixel_values.device
+            )
+
+        hidden_states = self.embeddings(
+            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+        )
+
+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+        # The call to `_upad_input` in `_flash_attention_forward` is expensive
+        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+        # avoid passing the attention_mask, which is equivalent to attending to the full sequence
+        if not torch.any(~patch_attention_mask):
+            patch_attention_mask = None
+        else:
+            patch_attention_mask = _prepare_4d_attention_mask(
+                patch_attention_mask, hidden_states.dtype
+            )
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )
+
+        last_hidden_state = encoder_outputs
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+class Idefics2MLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        act = config.text_config.hidden_act
+        self.act = (
+            ACT2FN[act]
+            if "gelu" not in act
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        )
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=False,
+        )
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.intermediate_size = (
+            config.text_config.intermediate_size // weights.process_group.size()
+        )
+
+    def forward(self, hidden_states):
+        gate_up_states = self.gate_up_proj(hidden_states)
+        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class Idefics2RMSNorm(nn.Module):
+    def __init__(self, prefix, weights, eps):
+        """
+        Idefics2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(
+            weights.get_tensor(f"{prefix}.weight"), requires_grad=False
+        )
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    Expand key/value heads for grouped-query attention; equivalent of
+    torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
+    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Idefics2PerceiverAttention(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+
+        self.layer_idx = None
+        self.hidden_size = config.text_config.hidden_size
+        self.num_heads = config.perceiver_config.resampler_n_heads
+        self.head_size = config.perceiver_config.resampler_head_dim
+        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.attention_dropout = config.perceiver_config.attention_dropout
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            self.num_key_value_heads // weights.process_group.size()
+        )
+
+        self.qkv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=False,
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
+        )
+
+        self.is_causal = False
+
+    def forward(
+        self,
+        latents: torch.Tensor,
+        context: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = latents.size()
+        kv_seq_len = q_len + context.size()[1]
+
+        hidden_states = torch.concat([context, latents], dim=-2)
+        qkv = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+            ],
+            dim=2,
+        )
+
+        # Only the latents (the last q_len positions of the concatenated
+        # sequence) act as queries.
+        query_states = query_states[:, -q_len:]
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_size
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+        ).transpose(1, 2)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_size)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class Idefics2PerceiverLayer(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.hidden_size = config.text_config.hidden_size
+        self.n_latents = config.perceiver_config.resampler_n_latents
+        self.depth = config.perceiver_config.resampler_depth
+        self.rms_norm_eps = config.text_config.rms_norm_eps
+
+        self.input_latents_norm = Idefics2RMSNorm(
+            prefix=f"{prefix}.input_latents_norm",
+            weights=weights,
+            eps=self.rms_norm_eps,
+        )
+        self.input_context_norm = Idefics2RMSNorm(
+            prefix=f"{prefix}.input_context_norm",
+            weights=weights,
+            eps=self.rms_norm_eps,
+        )
+        self.self_attn = Idefics2PerceiverAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.post_attention_layernorm = Idefics2RMSNorm(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=self.rms_norm_eps,
+        )
+        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+    def forward(
+        self,
+        latents: torch.Tensor,
+        context: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Args:
+            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+ """ + residual = latents + + latents = self.input_latents_norm(latents) + context = self.input_context_norm(context) + + latents = self.self_attn( + latents=latents, + context=context, + attention_mask=attention_mask, + ) + latents = residual + latents + residual = latents + + latents = self.post_attention_layernorm(latents) + latents = self.mlp(latents) + latents = residual + latents + + return latents + + +class Idefics2PerceiverResampler(nn.Module): + def __init__(self, prefix, config, weights) -> None: + super().__init__() + self.hidden_size = config.text_config.hidden_size + self.hidden_act = config.perceiver_config.hidden_act + self.n_latents = config.perceiver_config.resampler_n_latents + self.depth = config.perceiver_config.resampler_depth + self.rms_norm_eps = config.text_config.rms_norm_eps + + # Create Latents for Perceiver + self.latents = weights.get_tensor(f"{prefix}.latents") + + # Create Transformer Blocks + self.layers = nn.ModuleList( + [ + Idefics2PerceiverLayer( + prefix=f"{prefix}.layers.{idx}", config=config, weights=weights + ) + for idx in range(self.depth) + ] + ) + self.norm = Idefics2RMSNorm( + prefix=f"{prefix}.norm", + weights=weights, + eps=config.text_config.rms_norm_eps, + ) + + def forward( + self, + context: torch.Tensor, + attention_mask, + ) -> torch.Tensor: + # seq embed -> bsz seq embed + latents = self.latents.unsqueeze(0).expand( + (context.shape[0], *self.latents.size()) + ) + + latent_attention_mask = torch.ones( + (attention_mask.size(0), latents.size(1)), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) + attention_mask = _prepare_4d_attention_mask( + attention_mask, latents.dtype, tgt_len=self.n_latents + ) + + compressed_context = latents + for perceiver_layer in self.layers: + compressed_context = perceiver_layer( + compressed_context, + context, + attention_mask=attention_mask, + ) + compressed_context = self.norm(compressed_context) + + return compressed_context + + +class Idefics2Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics2MLP( + prefix=f"{prefix}.modality_projection", config=config, weights=weights + ) + self.perceiver_resampler = Idefics2PerceiverResampler( + prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights + ) + + def forward(self, image_hidden_states, attention_mask): + image_hidden_states = self.modality_projection(image_hidden_states) + image_hidden_states = self.perceiver_resampler( + context=image_hidden_states, attention_mask=attention_mask + ) + return image_hidden_states + + +class Idefics2ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + config.vision_config.use_medusa = config.use_medusa + config.text_config.quantize = config.quantize + config.text_config.use_medusa = config.use_medusa + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + self.vision_model = Idefics2VisionTransformer( + prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model", + config=vision_config, + weights=weights, + ) + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + 
+            weights=weights,
+        )
+        self.config = config
+        self.image_seq_len = config.perceiver_config.resampler_n_latents
+        self.image_token_id = config.image_token_id
+        self.pad_token_id = (
+            config.pad_token_id if config.pad_token_id is not None else -1
+        )
+
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In place merges in vision_embeddings with inputs_embeds."""
+        mask = input_ids == self.image_token_id
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor = None,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            batch_size, num_images, num_channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+            pixel_values = pixel_values.view(
+                batch_size * num_images, *pixel_values.shape[2:]
+            )
+
+            # Remove padding images - padding images are full 0.
+            nb_values_per_image = pixel_values.shape[1:].numel()
+            real_images_inds = (pixel_values == 0.0).sum(
+                dim=(-1, -2, -3)
+            ) != nb_values_per_image
+            pixel_values = pixel_values[real_images_inds].contiguous()
+
+            # Handle the vision attention mask
+            if pixel_attention_mask is None:
+                pixel_attention_mask = torch.ones(
+                    size=(
+                        pixel_values.size(0),
+                        pixel_values.size(2),
+                        pixel_values.size(3),
+                    ),
+                    dtype=torch.bool,
+                    device=pixel_values.device,
+                )
+            else:
+                # Remove padding images from the mask
+                pixel_attention_mask = pixel_attention_mask.view(
+                    batch_size * num_images, *pixel_attention_mask.shape[2:]
+                )
+                pixel_attention_mask = pixel_attention_mask[
+                    real_images_inds
+                ].contiguous()
+
+            patch_size = self.config.vision_config.patch_size
+            patches_subgrid = pixel_attention_mask.unfold(
+                dimension=1, size=patch_size, step=patch_size
+            )
+            patches_subgrid = patches_subgrid.unfold(
+                dimension=2, size=patch_size, step=patch_size
+            )
+            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+            # Get sequence from the vision encoder
+            image_hidden_states = self.vision_model(
+                pixel_values=pixel_values,
+                patch_attention_mask=patch_attention_mask,
+            )
+
+            # Modality projection & resampling
+            image_hidden_states = self.connector(
+                image_hidden_states,
+                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+            )
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, image_hidden_states
+            )
+
+        hidden_states = self.text_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            true_max_s=max_s,
+            prefill_cache_indices=None,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index ed21a52b..08ac9fcf 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -23,6 +23,10 @@ from torch import nn

 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
+from text_generation_server.models.custom_modeling.vlm import (
+    load_text_model,
+    load_vision_model,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states


-def load_vision_model(prefix, config, weights):
-    if config.model_type == "clip_vision_model":
-        from text_generation_server.models.custom_modeling.clip import (
-            CLIPVisionTransformer,
-        )
-
-        return CLIPVisionTransformer(
-            prefix=f"{prefix}.vision_model", config=config, weights=weights
-        )
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
-def load_text_model(prefix, config, weights):
-    if config.model_type == "llama":
-        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-            FlashLlamaForCausalLM,
-        )
-
-        return FlashLlamaForCausalLM(prefix, config, weights)
-    elif config.model_type == "mistral":
-        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-            FlashMistralForCausalLM,
-        )
-
-        return FlashMistralForCausalLM(prefix, config, weights)
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
 class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py
new file mode 100644
index 00000000..690957d0
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/vlm.py
@@ -0,0 +1,28 @@
+def load_text_model(prefix, config, weights, name=None):
+    if config.model_type == "llama":
+        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+            FlashLlamaForCausalLM,
+        )
+
+        return FlashLlamaForCausalLM(prefix, config, weights)
+    elif config.model_type == "mistral":
+        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+            FlashMistralForCausalLM,
+        )
+
+        return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
+
+
+def load_vision_model(prefix, config, weights):
+    if config.model_type == "clip_vision_model":
+        from text_generation_server.models.custom_modeling.clip import (
+            CLIPVisionTransformer,
+        )
+
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py
index 342609de..d88ff574 100644
--- a/server/text_generation_server/models/idefics2.py
+++ b/server/text_generation_server/models/idefics2.py
@@ -1,31 +1,18 @@
 import torch
-import torch.distributed

-from typing import List, Optional, Tuple
+from typing import Optional, Tuple

 from transformers import (
-    AutoTokenizer,
-    AutoConfig,
     AutoProcessor,
 )
-
-from text_generation_server.models.custom_modeling.idefics2_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import (
-    IdeficsProcessor,
-)
-from transformers import LlamaTokenizerFast
-from text_generation_server.models.custom_modeling.idefics2_modeling import (
-    Idefics2ForVisionText2Text,
-)
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
+from text_generation_server.models.custom_modeling.idefics2 import (
+    Idefics2ForConditionalGeneration,
 )
+from text_generation_server.models.vlm_causal_lm import VlmCausalLM

-class IDEFICS2Sharded(IdeficsCausalLM):
+
+class Idefics2(VlmCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -35,59 +22,25 @@ class IDEFICS2Sharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            # 9b seems to work correctly enough in float16, but 80b seems
-            # to be really saturating for f16.
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-        self.device, self.dtype = device, dtype
-
-        config = IdeficsConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
+        self.processor = AutoProcessor.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )
-        config.quantize = quantize
-        config.use_medusa = use_medusa
-        config.vision_config.quantize = quantize
-
-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
+        super().__init__(
+            model_cls=Idefics2ForConditionalGeneration,
+            model_id=model_id,
             revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        self.processor = IdeficsProcessor.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
+            quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
-            process_group=self.process_group,
+            trust_remote_code=trust_remote_code,
         )
-        model = IdeficsForVisionText2Text(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(IdeficsCausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.text_model.model.layers),
+            model.text_model.model.num_key_value_heads,
+            model.text_model.model.head_size,
         )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py
index 0ae1b46d..3983bc85 100644
--- a/server/text_generation_server/models/llava_next.py
+++ b/server/text_generation_server/models/llava_next.py
@@ -1,6 +1,6 @@
 import torch

-from typing import Optional
+from typing import Optional, Tuple

 from transformers import (
     AutoProcessor,
@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.language_model.model.layers),
+            model.language_model.model.num_key_value_heads,
+            model.language_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.language_model, "max_past", None)
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 16042fc9..5161a970 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -147,8 +147,9 @@ class VlmCausalLMBatch(FlashMistralBatch):
                         "Cannot process input image not starting with data:"
                     )
                 image_input = processor.image_processor(image, return_tensors="pt")
-                height, width = image_input["image_sizes"][0]
-                num_features = get_number_of_features(height, width, config)
+                # height, width = image_input["image_sizes"][0]
+                # num_features = get_number_of_features(height, width, config)
+                num_features = 1
                 full_text += "<image>" * num_features
                 image_inputs.append(image_input)
             else:
@@ -165,7 +167,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
+                "pixel_attention_mask": torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                ),
+                # "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
         else:
             image_inputs = None
@@ -187,10 +192,14 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                device=device
+            )
+            # batch.image_sizes = image_inputs["image_sizes"].to(device=device)
         else:
             batch.pixel_values = None
-            batch.image_sizes = None
+            batch.pixel_attention_mask = None
+            # batch.image_sizes = None
         return batch

@@ -199,16 +208,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -294,14 +293,17 @@ class VlmCausalLM(BaseFlashMistral):
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             pixel_values=batch.pixel_values,
-            image_sizes=batch.image_sizes,
+            pixel_attention_mask=batch.pixel_attention_mask,
+            # image_sizes=batch.image_sizes,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
-        if batch.image_sizes is not None:
-            batch.image_sizes = None
+        if batch.pixel_attention_mask is not None:
+            batch.pixel_attention_mask = None
+        # if batch.image_sizes is not None:
+        #     batch.image_sizes = None
         return logits, speculative_logits

     # Copy inputs to the static inputs of the cuda graph