diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py
index 856c642c..96bbc9aa 100644
--- a/server/text_generation_server/models/custom_modeling/clip.py
+++ b/server/text_generation_server/models/custom_modeling/clip.py
@@ -10,16 +10,23 @@ from transformers.modeling_attn_mask_utils import (
 )
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 
+from text_generation_server.utils.layers import (
+    TensorParallelEmbedding,
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
+
 
 class CLIPVisionEmbeddings(nn.Module):
-    def __init__(self, config: CLIPVisionConfig):
+    def __init__(self, prefix, config: CLIPVisionConfig, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
 
-        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+        # TODO Should we TP this ?
+        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
 
         self.patch_embedding = nn.Conv2d(
             in_channels=config.num_channels,
@@ -28,13 +35,18 @@ class CLIPVisionEmbeddings(nn.Module):
             stride=self.patch_size,
             bias=False,
         )
+        self.patch_embedding.weight = nn.Parameter(
+            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+        )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.position_embedding = TensorParallelEmbedding(
+            prefix=f"{prefix}.position_embedding", weights=weights
+        )
         self.register_buffer(
             "position_ids",
-            torch.arange(self.num_positions).expand((1, -1)),
+            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
             persistent=False,
         )
 
@@ -94,28 +106,38 @@ class CLIPTextEmbeddings(nn.Module):
 class CLIPAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
+        self.head_size = self.embed_dim // self.num_heads
+        if self.head_size * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
-        self.scale = self.head_dim**-0.5
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.scale = self.head_size**-0.5
         self.dropout = config.attention_dropout
 
-        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.qkv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=False,
+        )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return (
-            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
             .transpose(1, 2)
             .contiguous()
         )
@@ -132,11 +154,20 @@ class CLIPAttention(nn.Module):
         bsz, tgt_len, embed_dim = hidden_states.size()
 
         # get query proj
-        query_states = self.q_proj(hidden_states) * self.scale
-        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        qkv = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_size * self.num_heads,
+            ]
+            * 3,
+            dim=2,
+        )
+        query_states = query_states * self.scale
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_size)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
@@ -176,48 +207,38 @@ class CLIPAttention(nn.Module):
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-        if output_attentions:
-            # this operation is a bit akward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len
-            )
-            attn_weights = attn_weights_reshaped.view(
-                bsz * self.num_heads, tgt_len, src_len
-            )
-        else:
-            attn_weights_reshaped = None
-
         attn_probs = nn.functional.dropout(
             attn_weights, p=self.dropout, training=self.training
         )
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
                 f" {attn_output.size()}"
             )
 
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
         attn_output = attn_output.transpose(1, 2)
 
         attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights_reshaped
+        return attn_output, None
 
 
 class CLIPMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.fc1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
@@ -227,13 +248,19 @@ class CLIPMLP(nn.Module):
 
 class CLIPEncoderLayer(nn.Module):
-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = CLIPAttention(config)
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = CLIPMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = CLIPAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.layer_norm1 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+        )
+        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.layer_norm2 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+        )
 
     def forward(
         self,
@@ -281,73 +308,6 @@ class CLIPPreTrainedModel(nn.Module):
     base_model_prefix = "clip"
     supports_gradient_checkpointing = True
 
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        factor = self.config.initializer_factor
-        if isinstance(module, CLIPTextEmbeddings):
-            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-        elif isinstance(module, CLIPVisionEmbeddings):
-            factor = self.config.initializer_factor
-            nn.init.normal_(
-                module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor
-            )
-            nn.init.normal_(
-                module.patch_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-            nn.init.normal_(
-                module.position_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-        elif isinstance(module, CLIPAttention):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.embed_dim**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            out_proj_std = (module.embed_dim**-0.5) * factor
-            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
-        elif isinstance(module, CLIPMLP):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.config.hidden_size**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
-            nn.init.normal_(module.fc1.weight, std=fc_std)
-            nn.init.normal_(module.fc2.weight, std=in_proj_std)
-        elif isinstance(module, CLIPModel):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPVisionModelWithProjection):
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPTextModelWithProjection):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-
-        if isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
 
 CLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -458,13 +418,17 @@ class CLIPEncoder(nn.Module):
         config: CLIPConfig
     """
 
-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList(
-            [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+            [
+                CLIPEncoderLayer(
+                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
-        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -523,29 +487,15 @@ class CLIPEncoder(nn.Module):
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
         return hidden_states
 
@@ -555,7 +505,9 @@ class CLIPTextTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = CLIPTextEmbeddings(config)
-        self.encoder = CLIPEncoder(config)
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
         # For `pooled_output` computation
@@ -710,10 +662,20 @@ class CLIPVisionTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
 
-        self.embeddings = CLIPVisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.encoder = CLIPEncoder(config)
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.embeddings = CLIPVisionEmbeddings(
+            prefix=f"{prefix}.embeddings", config=config, weights=weights
+        )
+        self.pre_layrnorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+        )
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
+        self.post_layernorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.post_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 6631aba4..352907b8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -385,7 +385,32 @@ class MistralModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        return self.with_hidden_states(
+            hidden_states,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
 
+    def with_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+    ):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -409,7 +434,6 @@ class MistralModel(torch.nn.Module):
         )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-
         return hidden_states
 
 
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 8de11944..c5424a49 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -107,7 +107,9 @@ def load_vision_model(prefix, config, weights):
             CLIPVisionTransformer,
         )
 
-        return CLIPVisionTransformer(prefix, config, weights)
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
 
@@ -133,11 +135,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        # self.vision_tower = load_vision_model(
-        #     prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, weights=weights
-        # )
+        self.vision_tower = load_vision_model(
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+            config=config.vision_config,
+            weights=weights,
+        )
 
-        # self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
 
         self.image_newline = weights.get_tensor("image_newline")
 
@@ -153,7 +157,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
        )
-        # self.post_init()
 
     # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
     def _merge_input_ids_with_image_features(
@@ -278,118 +281,102 @@ class LlavaNextForConditionalGeneration(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
     ):
+        if pixel_values is not None and len(pixel_values) > 0:
+            num_special_image_tokens = (
+                input_ids == self.config.image_token_index
+            ).sum()
+            assert num_special_image_tokens == len(
+                pixel_values
+            ), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+            # 1. Extract the input embeddings
+            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-        # vision_feature_layer = (
-        #     vision_feature_layer
-        #     if vision_feature_layer is not None
-        #     else self.config.vision_feature_layer
-        # )
-        # vision_feature_select_strategy = (
-        #     vision_feature_select_strategy
-        #     if vision_feature_select_strategy is not None
-        #     else self.config.vision_feature_select_strategy
-        # )
+            # 2. Merge text and images
+            num_images, num_patches, channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.view(
+                num_images * num_patches, channels, height, width
+            )
+            image_features = self.vision_tower(pixel_values)
-        # if cu_seqlen_prefill is not None:
-        #     pass
-        # # # 1. Extract the input embeddings
-        # # inputs_embeds = self.get_input_embeddings()(input_ids)
+            selected_image_feature = image_features.hidden_states[
+                self.config.vision_feature_layer
+            ]
-        # # # 2. Merge text and images
-        # # if pixel_values is not None and input_ids.shape[1] != 1:
-        # #     batch_size, num_patches, num_channels, height, width = (
-        # #         pixel_values.shape
-        # #     )
-        # #     reshaped_pixel_values = pixel_values.view(
-        # #         batch_size * num_patches, num_channels, height, width
-        # #     )
-        # #     image_features = self.vision_tower(
-        # #         reshaped_pixel_values, output_hidden_states=True
-        # #     )
+            if self.config.vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.config.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise RuntimeError(
+                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+                )
-        # #     selected_image_feature = image_features.hidden_states[
-        # #         vision_feature_layer
-        # #     ]
+            image_features = self.multi_modal_projector(selected_image_feature)
-        # #     if vision_feature_select_strategy == "default":
-        # #         selected_image_feature = selected_image_feature[:, 1:]
-        # #     elif vision_feature_select_strategy == "full":
-        # #         selected_image_feature = selected_image_feature
+            # split up image_features for each of the individual images
+            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+            # if we assume each image has 5 image features (base image + 4 patches)
+            split_sizes = [num_patches] * num_images
+            image_features = torch.split(image_features, split_sizes, dim=0)
-        # #     image_features = self.multi_modal_projector(selected_image_feature)
+            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )
-        # #     # split up image_features for each of the individual images
-        # #     # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-        # #     # if we assume each image has 5 image features (base image + 4 patches)
-        # #     split_sizes = [image.shape[0] for image in pixel_values]
-        # #     image_features = torch.split(image_features, split_sizes, dim=0)
+            new_image_features = []
+            for image_idx, image_feature in enumerate(image_features):
+                if image_feature.shape[0] > 1:
+                    base_image_feature = image_feature[0]
+                    image_feature = image_feature[1:]
-        # #     # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-        # #     height = width = (
-        # #         self.config.vision_config.image_size
-        # #         // self.config.vision_config.patch_size
-        # #     )
+                    if height * width != base_image_feature.shape[0]:
+                        raise ValueError(
+                            "The number of patches is not consistent with the image size."
+                        )
+                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                        image_sizes[image_idx],
+                        self.config.image_grid_pinpoints,
+                        self.config.vision_config.image_size,
+                    )
+                    image_feature = image_feature.view(
+                        num_patch_height, num_patch_width, height, width, -1
+                    )
+                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            self.image_newline[:, None, None].expand(
+                                *image_feature.shape[:-1], 1
+                            ),
+                        ),
+                        dim=-1,
+                    )
+                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                    image_feature = torch.cat(
+                        (base_image_feature, image_feature), dim=0
+                    )
+                else:
+                    image_feature = image_feature[0]
+                    image_feature = torch.cat(
+                        (image_feature, self.image_newline[None]), dim=0
+                    )
+                new_image_features.append(image_feature)
+            image_features = torch.stack(new_image_features, dim=0)
-        # #     new_image_features = []
-        # #     for image_idx, image_feature in enumerate(image_features):
-        # #         if image_feature.shape[0] > 1:
-        # #             base_image_feature = image_feature[0]
-        # #             image_feature = image_feature[1:]
-
-        # #             if height * width != base_image_feature.shape[0]:
-        # #                 raise ValueError(
-        # #                     "The number of patches is not consistent with the image size."
-        # #                 )
-        # #             num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-        # #                 image_sizes[image_idx],
-        # #                 self.config.image_grid_pinpoints,
-        # #                 self.config.vision_config.image_size,
-        # #             )
-        # #             image_feature = image_feature.view(
-        # #                 num_patch_height, num_patch_width, height, width, -1
-        # #             )
-        # #             image_feature = image_feature.permute(
-        # #                 4, 0, 2, 1, 3
-        # #             ).contiguous()
-        # #             image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-        # #             image_feature = unpad_image(
-        # #                 image_feature, image_sizes[image_idx]
-        # #             )
-        # #             image_feature = torch.cat(
-        # #                 (
-        # #                     image_feature,
-        # #                     self.image_newline[:, None, None].expand(
-        # #                         *image_feature.shape[:-1], 1
-        # #                     ),
-        # #                 ),
-        # #                 dim=-1,
-        # #             )
-        # #             image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-        # #             image_feature = torch.cat(
-        # #                 (base_image_feature, image_feature), dim=0
-        # #             )
-        # #         else:
-        # #             image_feature = image_feature[0]
-        # #             image_feature = torch.cat(
-        # #                 (image_feature, self.image_newline[None]), dim=0
-        # #             )
-        # #         new_image_features.append(image_feature)
-        # #     image_features = torch.stack(new_image_features, dim=0)
-
-        # #     inputs_embeds, attention_mask, labels, position_ids = (
-        # #         self._merge_input_ids_with_image_features(
-        # #             image_features, inputs_embeds, input_ids, attention_mask, labels
-        # #         )
-        # #     )
-        # #     if labels is None:
-        # #         labels = torch.full_like(
-        # #             attention_mask, self.config.ignore_index
-        # #         ).to(torch.long)
+            inputs_embeds, attention_mask, labels, position_ids = (
+                self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+            )
+            if labels is None:
+                labels = torch.full_like(attention_mask, self.config.ignore_index).to(
+                    torch.long
+                )
 
         logits = self.language_model(
             input_ids,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 22c2fa6c..94a7f023 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
             max_tokens=self.blocks * BLOCK_SIZE,
         )
 
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer):
+        batch_inputs = []
+        max_truncation = 0
+        for r in requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        return batch_tokenized_inputs
+
     @classmethod
     def from_pb(
         cls,
@@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8806942a..6bbe7d88 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -65,19 +65,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window, sliding_window_blocks = get_sliding_windows()
 
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-
         position_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index d64e047b..211b425d 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -927,7 +927,7 @@ class IdeficsCausalLMBatch(Batch):
         )
 
     @classmethod
-    def from_pb(
+    def from_pb_processor(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 00088a4d..a04fc4d4 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -1,12 +1,18 @@
 import re
+import torch
 from opentelemetry import trace
 from typing import Optional, Tuple, List, Type, Dict
+from transformers import PreTrainedTokenizerBase
+from text_generation_server.pb import generate_pb2
 
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
+from text_generation_server.models.cache_manager import (
+    get_cache_manager,
+)
 
 tracer = trace.get_tracer(__name__)
 
@@ -31,13 +37,65 @@ def split(string) -> List[Dict[str, str]]:
 
 
 class VlmCausalLMBatch(FlashMistralBatch):
-    pass
+    pixel_values: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
+
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor):
+        batch_inputs = []
+        images = []
+        max_truncation = 0
+        for r in requests:
+            chunks = split(r.inputs)
+            full_text = ""
+            for chunk in chunks:
+                if chunk["type"] == "text":
+                    full_text += chunk["content"]
+                elif chunk["type"] == "image":
+                    full_text += ""
+                    images.append(chunk["content"])
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+
+            batch_inputs.append(full_text)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        images = processor.image_processor.fetch_images(images)
+        if images:
+            image_inputs = processor.image_processor(images, return_tensors="pt")
+        else:
+            image_inputs = None
+        return batch_tokenized_inputs, image_inputs
+
+    @classmethod
+    def from_pb_processor(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        processor,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "VlmCausalLMBatch":
+        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+            pb.requests, tokenizer, processor
+        )
+        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if image_inputs is not None:
+            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+        else:
+            batch.pixel_values = None
+            batch.image_sizes = None
+        return batch
 
 
 class VlmCausalLM(BaseFlashMistral):
     @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
+    def batch_type(self) -> Type[VlmCausalLMBatch]:
+        return VlmCausalLMBatch
 
     def get_layer_config(self, model) -> Tuple[int, int, int]:
         return (
@@ -48,3 +106,122 @@ class VlmCausalLM(BaseFlashMistral):
 
     def max_past(self) -> Optional[int]:
         return getattr(self.model.language_model, "max_past", None)
+
+    def forward(
+        self, batch: VlmCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            logits, speculative_logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+                pixel_values=batch.pixel_values,
+                image_sizes=batch.image_sizes,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            if batch.pixel_values is not None:
+                batch.pixel_values = None
+            if batch.image_sizes is not None:
+                batch.image_sizes = None
+            return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index beca36b9..cb9b302d 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
@@ -106,7 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
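
Illustration (not part of the patch): the LlavaNextForConditionalGeneration.forward hunk above flattens every patch of every image into a single batch for the vision tower and later splits the projected features back per image with torch.split. A minimal standalone sketch of that reshape/split round trip, with made-up tensor sizes:

    import torch

    # Hypothetical sizes for illustration only; real values come from the vision config.
    num_images, num_patches, channels, height, width = 2, 5, 3, 336, 336
    pixel_values = torch.randn(num_images, num_patches, channels, height, width)

    # (num_images, num_patches, C, H, W) -> (num_images * num_patches, C, H, W)
    flat = pixel_values.view(num_images * num_patches, channels, height, width)

    # Stand-in for the projector output: one feature row per patch token
    image_features = torch.randn(num_images * num_patches, 576, 4096)
    per_image = torch.split(image_features, [num_patches] * num_images, dim=0)
    assert len(per_image) == num_images
    assert per_image[0].shape == (num_patches, 576, 4096)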