diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 06219e7c..37c41e3b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -67,6 +67,7 @@ try: FlashSantacoderSharded, ) from text_generation_server.models.idefics import IDEFICSSharded + from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_phi import FlashPhi @@ -571,6 +572,19 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == "llava_next": + if FLASH_ATTENTION: + return LlavaNext( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext")) + if sharded: raise NotImplementedError("sharded is not supported for AutoModel") if quantize == "gptq": diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py new file mode 100644 index 00000000..856c642c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -0,0 +1,1043 @@ +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) +from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + 
torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, 
self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class CLIPMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: CLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPPreTrainedModel(nn.Module): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
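# Illustrative sketch (not part of this patch): the sequence length that flows through
# CLIPEncoderLayer when this file serves as the LLaVA-NeXT vision tower. The concrete
# image_size/patch_size values below are an assumption (CLIP ViT-L/14 at 336 px, the tower
# commonly paired with llava-v1.6 checkpoints); the diff itself does not fix them.
image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2   # 24 * 24 = 576, as in CLIPVisionEmbeddings
num_positions = num_patches + 1                 # + class embedding -> 577 tokens per crop
print(num_patches, num_positions)               # 576 577
# With vision_feature_select_strategy == "default" (see llava_next.py further down), the
# class token is dropped again, so each image crop contributes 576 features downstream.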
+ """ + + config_class = CLIPConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_( + module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor + ) + nn.init.normal_( + module.patch_embedding.weight, + std=module.config.initializer_range * factor, + ) + nn.init.normal_( + module.position_embedding.weight, + std=module.config.initializer_range * factor, + ) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = ( + (module.embed_dim**-0.5) + * ((2 * module.config.num_hidden_layers) ** -0.5) + * factor + ) + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) + * ((2 * module.config.num_hidden_layers) ** -0.5) + * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, CLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: CLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + return hidden_states + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. 
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + return last_hidden_state + + +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, prefix, config: CLIPVisionConfig, weights): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + 
return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPModel(nn.Module): + def __init__(self, prefix, config: CLIPConfig, weights): + super().__init__() + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, self.projection_dim, bias=False + ) + self.text_projection = nn.Linear( + self.text_embed_dim, self.projection_dim, bias=False + ) + self.logit_scale = nn.Parameter( + torch.tensor(self.config.logit_scale_init_value) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
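# Illustrative sketch (not part of this patch) of the similarity head at the end of
# CLIPModel.forward below: both projections are L2-normalised and the logits are a scaled
# matrix product. Shapes and the 2.6592 scale (roughly ln(1/0.07), the usual CLIP init)
# are illustrative only; the real value comes from config.logit_scale_init_value.
import torch

text_embeds = torch.randn(2, 512)
image_embeds = torch.randn(3, 512)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
logit_scale = torch.tensor(2.6592).exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale   # (2, 3)
logits_per_image = logits_per_text.t()                                        # (3, 2)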
+ output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + return logits_per_image, logits_per_text diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index ed9306e0..6631aba4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -285,9 +285,8 @@ class MistralMLP(nn.Module): class MistralLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -343,27 +342,27 @@ class MistralLayer(nn.Module): class MistralModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ MistralLayer( - layer_id, - config, - weights, + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, ) for layer_id in range(config.num_hidden_layers) ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -415,13 +414,17 @@ class MistralModel(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.model = MistralModel(config, weights) + self.model = MistralModel( + prefix="model" if not prefix else f"{prefix}.model", + config=config, + weights=weights, + ) self.lm_head = SpeculativeHead.load( config, - prefix="lm_head", + prefix="lm_head" if not prefix else f"{prefix}.lm_head", weights=weights, ) self.max_past = config.sliding_window diff --git 
a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py new file mode 100644 index 00000000..2392db8b --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -0,0 +1,424 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.image_processing_utils import select_best_resolution +from transformers import AutoModel, AutoModelForCausalLM + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. 
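# Illustrative sketch (not part of this patch) of the "anyres" bookkeeping used below.
# The grid_pinpoints list is an assumption taken from common llava-v1.6 configs; this
# file reads it from config.image_grid_pinpoints at runtime.
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
# If the best-fitting pinpoint for an input image is (672, 672), the forward pass calls
# get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size=336) -- note that the
# vision tower's image_size (336) is used as the "patch" unit -- yielding a grid of crops:
num_patch_height, num_patch_width = 672 // 336, 672 // 336
print(num_patch_height, num_patch_width)   # 2 2
# Each crop is then encoded by CLIP into 24 x 24 = 576 features, and unpad_image() strips
# the rows/columns that only exist because the image was padded to reach 672 x 672.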
+ """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def load_vision_model(prefix, config, weights): + if config.model_type == "clip_vision_model": + from text_generation_server.models.custom_modeling.clip import ( + CLIPVisionTransformer, + ) + + return CLIPVisionTransformer(prefix, config, weights) + else: + raise RuntimeError(f"Unsupported model type {config.model_type}") + + +def load_text_model(prefix, config, weights): + if config.model_type == "mistral": + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + + return FlashMistralForCausalLM(prefix, config, weights) + else: + raise RuntimeError(f"Unsupported model type {config.model_type}") + + +class LlavaNextForConditionalGeneration(nn.Module): + def __init__(self, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + self.vision_tower = load_vision_model( + prefix="vision_tower", config=config.vision_config, weights=weights + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector(config) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + config.text_config.quantize = config.quantize + config.text_config.use_medusa = config.use_medusa + self.language_model = load_text_model( + prefix="language_model", config=config.text_config, weights=weights + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + # self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, labels + ): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum( + input_ids[:, -1] == torch.tensor(self.pad_token_id) + ) + # 1. 
Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = ( + num_special_image_tokens.max() * (num_image_patches - 1) + ) + sequence_length + batch_indices, non_image_indices = torch.where( + input_ids != self.config.image_token_index + ) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = ( + torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) + - 1 + ) + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, + max_embed_dim, + embed_dim, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + final_attention_mask = torch.zeros( + batch_size, + max_embed_dim, + dtype=attention_mask.dtype, + device=inputs_embeds.device, + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), + self.config.ignore_index, + dtype=input_ids.dtype, + device=input_ids.device, + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ + batch_indices, non_image_indices + ] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ + batch_indices, non_image_indices + ] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[ + batch_indices, non_image_indices + ] + + # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling + image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[ + :, None + ].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." 
+ ) + + final_embedding[image_to_overwrite] = ( + image_features.contiguous().reshape(-1, embed_dim).to(target_device) + ) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( + (final_attention_mask == 0), 1 + ) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + batch_size, num_patches, num_channels, height, width = ( + pixel_values.shape + ) + reshaped_pixel_values = pixel_values.view( + batch_size * num_patches, num_channels, height, width + ) + image_features = self.vision_tower( + reshaped_pixel_values, output_hidden_states=True + ) + + selected_image_feature = image_features.hidden_states[ + vision_feature_layer + ] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." 
+ ) + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds, attention_mask, labels, position_ids = ( + self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + ) + if labels is None: + labels = torch.full_like( + attention_mask, self.config.ignore_index + ).to(torch.long) + + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif ( + past_key_values is not None + and pixel_values is not None + and input_ids.shape[1] == 1 + ): + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where( + first_layer_past_key_value.float().sum(-2) == 0 + ) + + # Get the target length + target_seqlen = first_layer_past_key_value.shape[-1] + 1 + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat( + (attention_mask, extended_attention_mask), dim=1 + ) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + ) + + logits = outputs[0] + return logits diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 30bf4aa6..1a63934f 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -17,7 +17,7 @@ from transformers import LlamaTokenizerFast from text_generation_server.models.custom_modeling.idefics_modeling import ( IdeficsForVisionText2Text, ) -from 
text_generation_server.models.idefics_causal_lm import IdeficsCausalLM +from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -25,7 +25,7 @@ from text_generation_server.utils import ( ) -class IDEFICSSharded(IdeficsCausalLM): +class IDEFICSSharded(VlmCausalLM): def __init__( self, model_id: str, @@ -82,7 +82,7 @@ class IDEFICSSharded(IdeficsCausalLM): model = IdeficsForVisionText2Text(config, weights) torch.distributed.barrier(group=self.process_group) - super(IdeficsCausalLM, self).__init__( + super(VlmCausalLM, self).__init__( model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py new file mode 100644 index 00000000..9e35d0fc --- /dev/null +++ b/server/text_generation_server/models/llava_next.py @@ -0,0 +1,89 @@ +import torch +import torch.distributed + +from typing import List, Optional, Tuple + +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoProcessor, +) +from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, +) + +# from transformers import AutoConfig, AutoTokenizer, AutoProcessor +from text_generation_server.models.vlm_causal_lm import VlmCausalLM +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + + +class LlavaNext(VlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + # 9b seems to work correctly enough in float16, but 80b seems + # to be really saturating for f16. 
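# Illustrative sketch (not part of this patch) of how the prefix refactor in
# flash_mistral_modeling.py above composes once LlavaNextForConditionalGeneration loads
# its text tower with prefix="language_model". The tensor names follow the llava-hf
# checkpoint layout and are given for illustration only.
prefix = "language_model"
model_prefix = "model" if not prefix else f"{prefix}.model"
lm_head_prefix = "lm_head" if not prefix else f"{prefix}.lm_head"
print(f"{model_prefix}.embed_tokens.weight")               # language_model.model.embed_tokens.weight
print(f"{model_prefix}.layers.0.self_attn.q_proj.weight")  # language_model.model.layers.0.self_attn.q_proj.weight
print(f"{lm_head_prefix}.weight")                          # language_model.lm_head.weight
# A plain Mistral checkpoint keeps the old "model.*" / "lm_head" names because its prefix
# is empty, so existing FlashMistral deployments are unaffected by the refactor.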
+ dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + self.device, self.dtype = device, dtype + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.use_medusa = use_medusa + config.vision_config.quantize = quantize + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + ) + + model = LlavaNextForConditionalGeneration(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(VlmCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py similarity index 92% rename from server/text_generation_server/models/idefics_causal_lm.py rename to server/text_generation_server/models/vlm_causal_lm.py index c96e8152..a7cfc3e4 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -47,7 +47,7 @@ tracer = trace.get_tracer(__name__) @dataclass -class IdeficsCausalLMBatch(Batch): +class VlmCausalLMBatch(Batch): batch_id: int requests: List[generate_pb2.Request] requests_idx_mapping: Dict[int, int] @@ -99,7 +99,7 @@ class IdeficsCausalLMBatch(Batch): processor: ProcessorMixin, # Hack dtype: torch.dtype, device: torch.device, - ) -> "IdeficsCausalLMBatch": + ) -> "VlmCausalLMBatch": inputs = [] next_token_choosers = [] stopping_criterias = [] @@ -127,21 +127,23 @@ class IdeficsCausalLMBatch(Batch): padding_right_offset, stopping_criteria.max_new_tokens ) - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompts.append(split(inp)) + # TODO Check impact on idefics + # prompts = [] + # for inp in inputs: + # # Each input is encoded into a list, where each element of this input list is either a string or a URL + # prompts.append(split(inp)) # The processor replaces the call to tokenizer, and # a/ takes care of fetching images from the URL # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model tokenized_inputs = processor( - prompts, + inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_truncation, - add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + # TODO Check impact on idefics + # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] @@ -156,7 +158,7 @@ class IdeficsCausalLMBatch(Batch): max_input_length = input_lengths.max() input_ids = tokenized_inputs["input_ids"] - 
pixel_values = tokenized_inputs["pixel_values"] + pixel_values = tokenized_inputs.get("pixel_values", None) image_hidden_states = None # Allocate maximum attention_mask attention_mask = input_ids.new_zeros( @@ -165,16 +167,19 @@ class IdeficsCausalLMBatch(Batch): # Copy tokenizer attention_mask into fully allocated attention_mask attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] # Do the same for image_attention_mask - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - tokenized_inputs["pixel_values"].size(1), + if pixel_values is None: + image_attention_mask = None + else: + image_attention_mask = input_ids.new_zeros( + ( + pb.size, + max_input_length + padding_right_offset, + pixel_values.size(1), + ) ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] + image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ + "image_attention_mask" + ] position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) @@ -207,7 +212,7 @@ class IdeficsCausalLMBatch(Batch): ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: + def filter(self, request_ids: List[int]) -> Optional["VlmCausalLMBatch"]: # It deletes requests from the batch. For instance when client lost connection if len(request_ids) == 0: raise ValueError("Batch must have at least one request") @@ -323,9 +328,7 @@ class IdeficsCausalLMBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["IdeficsCausalLMBatch"] - ) -> "IdeficsCausalLMBatch": + def concatenate(cls, batches: List["VlmCausalLMBatch"]) -> "VlmCausalLMBatch": # It adds new requests to the batch # Used for padding total_batch_size = 0 @@ -564,7 +567,7 @@ class IdeficsCausalLMBatch(Batch): return len(self.requests) -class IdeficsCausalLM(Model): +class VlmCausalLM(Model): def __init__( self, model_id: str, @@ -574,7 +577,7 @@ class IdeficsCausalLM(Model): trust_remote_code: bool = False, ): from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, + VlmForVisionText2Text, ) if torch.cuda.is_available(): @@ -601,7 +604,7 @@ class IdeficsCausalLM(Model): truncation_side="left", trust_remote_code=trust_remote_code, ) - model = IdeficsForVisionText2Text.from_pretrained( + model = VlmForVisionText2Text.from_pretrained( model_id, revision=revision, torch_dtype=dtype, @@ -626,7 +629,7 @@ class IdeficsCausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": ""}) - super(IdeficsCausalLM, self).__init__( + super(VlmCausalLM, self).__init__( model=model, tokenizer=tokenizer, requires_padding=True, @@ -635,8 +638,8 @@ class IdeficsCausalLM(Model): ) @property - def batch_type(self) -> Type[IdeficsCausalLMBatch]: - return IdeficsCausalLMBatch + def batch_type(self) -> Type[VlmCausalLMBatch]: + return VlmCausalLMBatch def forward( self, @@ -672,24 +675,27 @@ class IdeficsCausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: + self, batch: VlmCausalLMBatch + ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: start = time.time_ns() # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : 
-batch.padding_right_offset] - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) + if batch.image_attention_mask is None: + image_attention_mask = None else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] + if batch.input_ids.size(1) == 1: + # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), + # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension + # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated + # token need to attend to the encoder hidden states (i.e. the vision encoder) + # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic + image_attention_mask = batch.image_attention_mask[ + :, -(batch.padding_right_offset + 1) + ].unsqueeze(1) + else: + image_attention_mask = batch.image_attention_mask[ + :, : -batch.padding_right_offset + ] logits, speculative_logits, past, image_hidden_states = self.forward( input_ids=batch.input_ids, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index d5adbd32..308fb1e2 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -15,7 +15,7 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch +from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -79,7 +79,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): pass if ( - self.model.batch_type == IdeficsCausalLMBatch + self.model.batch_type == VlmCausalLMBatch ): # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( request.batch, @@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): start = time.time_ns() if ( - self.model.batch_type == IdeficsCausalLMBatch + self.model.batch_type == VlmCausalLMBatch ): # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( request.batch,
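
Editor's note: the from_pb hunks above make pixel_values and image_attention_mask optional, so text-only requests no longer assume image inputs. Below is a minimal, self-contained sketch of that pattern; it is not part of the patch, and the helper name pad_vlm_inputs plus the demo tensors are illustrative only.

import torch


def pad_vlm_inputs(tokenized_inputs: dict, padding_right_offset: int):
    # Mirrors the padding logic in VlmCausalLMBatch.from_pb: the text attention
    # mask is over-allocated to leave room for generated tokens, and the image
    # attention mask is only built when the processor actually returned images.
    input_ids = tokenized_inputs["input_ids"]
    batch_size, max_input_length = input_ids.shape

    attention_mask = input_ids.new_zeros(
        (batch_size, max_input_length + padding_right_offset)
    )
    attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

    pixel_values = tokenized_inputs.get("pixel_values", None)
    if pixel_values is None:
        # Text-only batch: no image tensors, no image attention mask.
        image_attention_mask = None
    else:
        image_attention_mask = input_ids.new_zeros(
            (batch_size, max_input_length + padding_right_offset, pixel_values.size(1))
        )
        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
            "image_attention_mask"
        ]

    return attention_mask, pixel_values, image_attention_mask


if __name__ == "__main__":
    # Text-only request: the "pixel_values" key is simply absent from the
    # processor output, so both image tensors come back as None.
    text_only = {
        "input_ids": torch.ones((2, 5), dtype=torch.long),
        "attention_mask": torch.ones((2, 5), dtype=torch.long),
    }
    _, pixel_values, image_mask = pad_vlm_inputs(text_only, padding_right_offset=3)
    assert pixel_values is None and image_mask is None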
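
Likewise, generate_token now guards the image_attention_mask slicing behind a None check, keeping the full mask for prefill and only the last valid position during decode. A standalone sketch of that branch follows; the function name is illustrative and does not appear in the patch.

from typing import Optional

import torch


def slice_image_attention_mask(
    image_attention_mask: Optional[torch.Tensor],
    input_ids: torch.Tensor,
    padding_right_offset: int,
) -> Optional[torch.Tensor]:
    # Text-only batches carry no image mask at all.
    if image_attention_mask is None:
        return None
    if input_ids.size(1) == 1:
        # Decode step: only the freshly generated token needs to attend to the
        # vision encoder states, so keep the last non-padded position along the
        # sequence dimension.
        return image_attention_mask[:, -(padding_right_offset + 1)].unsqueeze(1)
    # Prefill: keep every position up to the right-padding area.
    return image_attention_mask[:, :-padding_right_offset]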