More work on the CLIP Side.

Nicolas Patry 2024-04-04 18:08:38 +00:00
parent b8be0d1ae7
commit 5f4b395480
8 changed files with 429 additions and 273 deletions

View File

@@ -10,16 +10,23 @@ from transformers.modeling_attn_mask_utils import (
 )
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from text_generation_server.utils.layers import (
+    TensorParallelEmbedding,
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)


 class CLIPVisionEmbeddings(nn.Module):
-    def __init__(self, config: CLIPVisionConfig):
+    def __init__(self, prefix, config: CLIPVisionConfig, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
-        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+        # TODO Should we TP this ?
+        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")

         self.patch_embedding = nn.Conv2d(
             in_channels=config.num_channels,
@@ -28,13 +35,18 @@ class CLIPVisionEmbeddings(nn.Module):
             stride=self.patch_size,
             bias=False,
         )
+        self.patch_embedding.weight = nn.Parameter(
+            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+        )

         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.position_embedding = TensorParallelEmbedding(
+            prefix=f"{prefix}.position_embedding", weights=weights
+        )
         self.register_buffer(
             "position_ids",
-            torch.arange(self.num_positions).expand((1, -1)),
+            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
             persistent=False,
         )
@@ -94,28 +106,38 @@ class CLIPTextEmbeddings(nn.Module):
 class CLIPAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
+        self.head_size = self.embed_dim // self.num_heads
+        if self.head_size * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
-        self.scale = self.head_dim**-0.5
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.scale = self.head_size**-0.5
         self.dropout = config.attention_dropout

-        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.qkv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=False,
+        )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return (
-            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
             .transpose(1, 2)
             .contiguous()
         )
@@ -132,11 +154,20 @@ class CLIPAttention(nn.Module):
         bsz, tgt_len, embed_dim = hidden_states.size()

         # get query proj
-        query_states = self.q_proj(hidden_states) * self.scale
-        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        qkv = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_size * self.num_heads,
+            ]
+            * 3,
+            dim=2,
+        )
+        query_states = query_states * self.scale
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_size)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
@@ -176,48 +207,38 @@ class CLIPAttention(nn.Module):
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)

-        if output_attentions:
-            # this operation is a bit akward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len
-            )
-            attn_weights = attn_weights_reshaped.view(
-                bsz * self.num_heads, tgt_len, src_len
-            )
-        else:
-            attn_weights_reshaped = None
-
         attn_probs = nn.functional.dropout(
             attn_weights, p=self.dropout, training=self.training
         )

         attn_output = torch.bmm(attn_probs, value_states)

-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
                 f" {attn_output.size()}"
             )

-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)

-        return attn_output, attn_weights_reshaped
+        return attn_output, None


 class CLIPMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.fc1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
@@ -227,13 +248,19 @@ class CLIPMLP(nn.Module):
 class CLIPEncoderLayer(nn.Module):
-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = CLIPAttention(config)
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = CLIPMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = CLIPAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.layer_norm1 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+        )
+        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.layer_norm2 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+        )

     def forward(
         self,
@@ -281,73 +308,6 @@ class CLIPPreTrainedModel(nn.Module):
     base_model_prefix = "clip"
     supports_gradient_checkpointing = True

-    def _init_weights(self, module):
-        """Initialize the weights"""
-        factor = self.config.initializer_factor
-        if isinstance(module, CLIPTextEmbeddings):
-            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-        elif isinstance(module, CLIPVisionEmbeddings):
-            factor = self.config.initializer_factor
-            nn.init.normal_(
-                module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor
-            )
-            nn.init.normal_(
-                module.patch_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-            nn.init.normal_(
-                module.position_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-        elif isinstance(module, CLIPAttention):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.embed_dim**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            out_proj_std = (module.embed_dim**-0.5) * factor
-            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
-        elif isinstance(module, CLIPMLP):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.config.hidden_size**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
-            nn.init.normal_(module.fc1.weight, std=fc_std)
-            nn.init.normal_(module.fc2.weight, std=in_proj_std)
-        elif isinstance(module, CLIPModel):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPVisionModelWithProjection):
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPTextModelWithProjection):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-
-        if isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-

 CLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -458,13 +418,17 @@ class CLIPEncoder(nn.Module):
         config: CLIPConfig
     """

-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList(
-            [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+            [
+                CLIPEncoderLayer(
+                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
-        self.gradient_checkpointing = False

     def forward(
         self,
@@ -523,29 +487,15 @@ class CLIPEncoder(nn.Module):
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]

-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
         return hidden_states
@@ -555,7 +505,9 @@ class CLIPTextTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = CLIPTextEmbeddings(config)
-        self.encoder = CLIPEncoder(config)
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

         # For `pooled_output` computation
@@ -710,10 +662,20 @@ class CLIPVisionTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size

-        self.embeddings = CLIPVisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.encoder = CLIPEncoder(config)
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.embeddings = CLIPVisionEmbeddings(
+            prefix=f"{prefix}.embeddings", config=config, weights=weights
+        )
+        self.pre_layrnorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+        )
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
+        self.post_layernorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.post_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )

     def forward(
         self,

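Note on the attention change above: once `num_heads` is divided by the process-group size, each tensor-parallel rank's fused `qkv` projection produces `3 * local_heads * head_size` features, and the `split` call slices them back into that rank's q, k and v. A minimal standalone sketch of just that splitting logic (a plain `nn.Linear` stands in for `TensorParallelColumnLinear.load_multi`, and the sizes are illustrative, not taken from any real config):

import torch
from torch import nn

# Illustrative sizes: hidden size 64, 8 heads in total, 2 tensor-parallel ranks.
embed_dim, num_heads, world_size = 64, 8, 2
head_size = embed_dim // num_heads          # 8
local_heads = num_heads // world_size       # 4 heads handled by this rank

# Fused projection for this rank: q, k and v concatenated along the feature dim.
qkv_proj = nn.Linear(embed_dim, 3 * local_heads * head_size)

hidden_states = torch.randn(1, 10, embed_dim)                # (bsz, seq_len, embed_dim)
qkv = qkv_proj(hidden_states)                                # (1, 10, 96)
q, k, v = qkv.split([local_heads * head_size] * 3, dim=2)
assert q.shape == k.shape == v.shape == (1, 10, local_heads * head_size)

The row-parallel `out_proj` then reduces the partial outputs across ranks, which is why it is loaded without a per-rank bias here.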
View File

@@ -385,7 +385,32 @@ class MistralModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        return self.with_hidden_states(
+            hidden_states,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
+
+    def with_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+    ):

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -409,7 +434,6 @@ class MistralModel(torch.nn.Module):
             )

         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states

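The `forward`/`with_hidden_states` split gives a multimodal wrapper a way to compute input embeddings itself (text token embeddings merged with projected image features) and then hand them to the shared Mistral trunk. A toy sketch of that pattern under simplified assumptions (none of these classes or sizes come from the diff):

import torch
from torch import nn

class TinyTrunk(nn.Module):
    """Toy stand-in: embedding lookup split from the shared trunk."""

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.layers = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())

    def forward(self, input_ids):
        # Text-only path: embed, then reuse the hidden-states path.
        return self.with_hidden_states(self.embed_tokens(input_ids))

    def with_hidden_states(self, hidden_states):
        # Shared trunk; a VLM can call this directly with embeddings it
        # has already merged with image features.
        return self.layers(hidden_states)

model = TinyTrunk()
ids = torch.randint(0, 100, (1, 4))
text_out = model(ids)                           # plain text path
merged = model.embed_tokens(ids) + 0.1          # pretend image features were merged in
vlm_out = model.with_hidden_states(merged)      # multimodal path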
View File

@@ -107,7 +107,9 @@ def load_vision_model(prefix, config, weights):
             CLIPVisionTransformer,
         )

-        return CLIPVisionTransformer(prefix, config, weights)
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -133,11 +135,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        # self.vision_tower = load_vision_model(
-        #     prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, weights=weights
-        # )
+        self.vision_tower = load_vision_model(
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+            config=config.vision_config,
+            weights=weights,
+        )

-        # self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        self.multi_modal_projector = LlavaNextMultiModalProjector(config)

         self.image_newline = weights.get_tensor("image_newline")
@@ -153,7 +157,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
         )
-        # self.post_init()

     # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
     def _merge_input_ids_with_image_features(
@@ -278,118 +281,102 @@ class LlavaNextForConditionalGeneration(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
     ):
-        # vision_feature_layer = (
-        #     vision_feature_layer
-        #     if vision_feature_layer is not None
-        #     else self.config.vision_feature_layer
-        # )
-        # vision_feature_select_strategy = (
-        #     vision_feature_select_strategy
-        #     if vision_feature_select_strategy is not None
-        #     else self.config.vision_feature_select_strategy
-        # )
-
-        # if cu_seqlen_prefill is not None:
-        #     pass
-        # #     # 1. Extract the input embeddings
-        # #     inputs_embeds = self.get_input_embeddings()(input_ids)
-        # #     # 2. Merge text and images
-        # #     if pixel_values is not None and input_ids.shape[1] != 1:
-        # #         batch_size, num_patches, num_channels, height, width = (
-        # #             pixel_values.shape
-        # #         )
-        # #         reshaped_pixel_values = pixel_values.view(
-        # #             batch_size * num_patches, num_channels, height, width
-        # #         )
-        # #         image_features = self.vision_tower(
-        # #             reshaped_pixel_values, output_hidden_states=True
-        # #         )
-        # #         selected_image_feature = image_features.hidden_states[
-        # #             vision_feature_layer
-        # #         ]
-        # #         if vision_feature_select_strategy == "default":
-        # #             selected_image_feature = selected_image_feature[:, 1:]
-        # #         elif vision_feature_select_strategy == "full":
-        # #             selected_image_feature = selected_image_feature
-
-        # #         image_features = self.multi_modal_projector(selected_image_feature)
-
-        # #         # split up image_features for each of the individual images
-        # #         # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-        # #         # if we assume each image has 5 image features (base image + 4 patches)
-        # #         split_sizes = [image.shape[0] for image in pixel_values]
-        # #         image_features = torch.split(image_features, split_sizes, dim=0)
-        # #         # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-        # #         height = width = (
-        # #             self.config.vision_config.image_size
-        # #             // self.config.vision_config.patch_size
-        # #         )
-
-        # #         new_image_features = []
-        # #         for image_idx, image_feature in enumerate(image_features):
-        # #             if image_feature.shape[0] > 1:
-        # #                 base_image_feature = image_feature[0]
-        # #                 image_feature = image_feature[1:]
-
-        # #                 if height * width != base_image_feature.shape[0]:
-        # #                     raise ValueError(
-        # #                         "The number of patches is not consistent with the image size."
-        # #                     )
-        # #                 num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-        # #                     image_sizes[image_idx],
-        # #                     self.config.image_grid_pinpoints,
-        # #                     self.config.vision_config.image_size,
-        # #                 )
-        # #                 image_feature = image_feature.view(
-        # #                     num_patch_height, num_patch_width, height, width, -1
-        # #                 )
-        # #                 image_feature = image_feature.permute(
-        # #                     4, 0, 2, 1, 3
-        # #                 ).contiguous()
-        # #                 image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-        # #                 image_feature = unpad_image(
-        # #                     image_feature, image_sizes[image_idx]
-        # #                 )
-        # #                 image_feature = torch.cat(
-        # #                     (
-        # #                         image_feature,
-        # #                         self.image_newline[:, None, None].expand(
-        # #                             *image_feature.shape[:-1], 1
-        # #                         ),
-        # #                     ),
-        # #                     dim=-1,
-        # #                 )
-        # #                 image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-        # #                 image_feature = torch.cat(
-        # #                     (base_image_feature, image_feature), dim=0
-        # #                 )
-        # #             else:
-        # #                 image_feature = image_feature[0]
-        # #                 image_feature = torch.cat(
-        # #                     (image_feature, self.image_newline[None]), dim=0
-        # #                 )
-        # #             new_image_features.append(image_feature)
-        # #         image_features = torch.stack(new_image_features, dim=0)
-
-        # #         inputs_embeds, attention_mask, labels, position_ids = (
-        # #             self._merge_input_ids_with_image_features(
-        # #                 image_features, inputs_embeds, input_ids, attention_mask, labels
-        # #             )
-        # #         )
-        # #         if labels is None:
-        # #             labels = torch.full_like(
-        # #                 attention_mask, self.config.ignore_index
-        # #             ).to(torch.long)
+        if pixel_values is not None and len(pixel_values) > 0:
+            num_special_image_tokens = (
+                input_ids == self.config.image_token_index
+            ).sum()
+            assert num_special_image_tokens == len(
+                pixel_values
+            ), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+            # 1. Extract the input embeddings
+            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+            # 2. Merge text and images
+            num_images, num_patches, channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.view(
+                num_images * num_patches, channels, height, width
+            )
+            image_features = self.vision_tower(pixel_values)
+
+            selected_image_feature = image_features.hidden_states[
+                self.config.vision_feature_layer
+            ]
+
+            if self.config.vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.config.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise RuntimeError(
+                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+                )
+
+            image_features = self.multi_modal_projector(selected_image_feature)
+
+            # split up image_features for each of the individual images
+            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+            # if we assume each image has 5 image features (base image + 4 patches)
+            split_sizes = [num_patches] * num_images
+            image_features = torch.split(image_features, split_sizes, dim=0)
+
+            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )
+
+            new_image_features = []
+            for image_idx, image_feature in enumerate(image_features):
+                if image_feature.shape[0] > 1:
+                    base_image_feature = image_feature[0]
+                    image_feature = image_feature[1:]
+
+                    if height * width != base_image_feature.shape[0]:
+                        raise ValueError(
+                            "The number of patches is not consistent with the image size."
+                        )
+                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                        image_sizes[image_idx],
+                        self.config.image_grid_pinpoints,
+                        self.config.vision_config.image_size,
+                    )
+                    image_feature = image_feature.view(
+                        num_patch_height, num_patch_width, height, width, -1
+                    )
+                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            self.image_newline[:, None, None].expand(
+                                *image_feature.shape[:-1], 1
+                            ),
+                        ),
+                        dim=-1,
+                    )
+                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                    image_feature = torch.cat(
+                        (base_image_feature, image_feature), dim=0
+                    )
+                else:
+                    image_feature = image_feature[0]
+                    image_feature = torch.cat(
+                        (image_feature, self.image_newline[None]), dim=0
+                    )
+                new_image_features.append(image_feature)
+            image_features = torch.stack(new_image_features, dim=0)
+
+            inputs_embeds, attention_mask, labels, position_ids = (
+                self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+            )
+            if labels is None:
+                labels = torch.full_like(attention_mask, self.config.ignore_index).to(
+                    torch.long
+                )

         logits = self.language_model(
             input_ids,

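A quick worked example of the patch arithmetic in the new LLaVA-NeXT path, assuming CLIP-ViT-L/336-style values of image_size=336 and patch_size=14 (these numbers are illustrative, not read from the config in the diff):

# Feature-map side length per crop: 336 // 14 = 24, so 24 * 24 = 576 patch tokens.
image_size, patch_size = 336, 14
height = width = image_size // patch_size             # 24
tokens_per_crop = height * width                      # 576

# With an any-resolution grid of 2 x 2 high-res crops plus the base image,
# each input image contributes 5 crops ("base image + 4 patches" in the comment).
num_patch_height, num_patch_width = 2, 2
num_crops = 1 + num_patch_height * num_patch_width    # 5
print(tokens_per_crop, num_crops * tokens_per_crop)   # 576 2880, before unpad_image
                                                      # and the image_newline tokens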
View File

@@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
             max_tokens=self.blocks * BLOCK_SIZE,
         )

+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer):
+        batch_inputs = []
+        max_truncation = 0
+        for r in requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        return batch_tokenized_inputs
+
     @classmethod
     def from_pb(
         cls,
@@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)

         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]

View File

@@ -65,19 +65,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window, sliding_window_blocks = get_sliding_windows()

-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-
         position_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []

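The two batch-class changes above split the old monolithic `from_pb` into a tokenization step (`batch_tokenized_inputs`) and a batch-building step (`from_tokenized`), so a subclass only has to override the text/image preparation. A schematic sketch of the pattern with toy classes (names and bodies are simplified placeholders, not the real TGI types; `tokenizer` and `processor` are assumed to be provided callables):

from typing import List, Optional


class TextBatch:
    """Toy stand-in for the flash batch: tokenization split from batch building."""

    @classmethod
    def batch_tokenized_inputs(cls, inputs: List[str], tokenizer) -> List[List[int]]:
        return tokenizer(inputs, truncation=True)["input_ids"]

    @classmethod
    def from_pb(cls, inputs: List[str], tokenizer):
        tokens = cls.batch_tokenized_inputs(inputs, tokenizer)
        return cls.from_tokenized(inputs, tokens)

    @classmethod
    def from_tokenized(cls, inputs: List[str], tokens: List[List[int]]):
        batch = cls()
        batch.input_ids = tokens  # the real code also builds positions, slots, blocks...
        return batch


class VlmBatch(TextBatch):
    """Reuses from_tokenized, but enters through from_pb_processor to attach images."""

    pixel_values: Optional[object] = None

    @classmethod
    def from_pb_processor(cls, inputs: List[str], tokenizer, processor):
        tokens = cls.batch_tokenized_inputs(inputs, tokenizer)
        batch = cls.from_tokenized(inputs, tokens)
        batch.pixel_values = processor(inputs)  # placeholder for image preprocessing
        return batch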
View File

@@ -927,7 +927,7 @@ class IdeficsCausalLMBatch(Batch):
         )

     @classmethod
-    def from_pb(
+    def from_pb_processor(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,

View File

@@ -1,12 +1,18 @@
 import re
+import torch

 from opentelemetry import trace
 from typing import Optional, Tuple, List, Type, Dict

+from transformers import PreTrainedTokenizerBase
+from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
+from text_generation_server.models.cache_manager import (
+    get_cache_manager,
+)

 tracer = trace.get_tracer(__name__)
@@ -31,13 +37,65 @@ def split(string) -> List[Dict[str, str]]:


 class VlmCausalLMBatch(FlashMistralBatch):
-    pass
+    pixel_values: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
+
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor):
+        batch_inputs = []
+        images = []
+        max_truncation = 0
+        for r in requests:
+            chunks = split(r.inputs)
+            full_text = ""
+            for chunk in chunks:
+                if chunk["type"] == "text":
+                    full_text += chunk["content"]
+                elif chunk["type"] == "image":
+                    full_text += "<image>"
+                    images.append(chunk["content"])
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+            batch_inputs.append(full_text)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        images = processor.image_processor.fetch_images(images)
+        if images:
+            image_inputs = processor.image_processor(images, return_tensors="pt")
+        else:
+            image_inputs = None
+        return batch_tokenized_inputs, image_inputs
+
+    @classmethod
+    def from_pb_processor(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        processor,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "VlmCausalLMBatch":
+        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+            pb.requests, tokenizer, processor
+        )
+        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if image_inputs is not None:
+            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+        else:
+            batch.pixel_values = None
+            batch.image_sizes = None
+        return batch


 class VlmCausalLM(BaseFlashMistral):
     @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
+    def batch_type(self) -> Type[VlmCausalLMBatch]:
+        return VlmCausalLMBatch

     def get_layer_config(self, model) -> Tuple[int, int, int]:
         return (
@@ -48,3 +106,122 @@ class VlmCausalLM(BaseFlashMistral):

     def max_past(self) -> Optional[int]:
         return getattr(self.model.language_model, "max_past", None)
+
+    def forward(
+        self, batch: VlmCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            logits, speculative_logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+                pixel_values=batch.pixel_values,
+                image_sizes=batch.image_sizes,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            if batch.pixel_values is not None:
+                batch.pixel_values = None
+            if batch.image_sizes is not None:
+                batch.image_sizes = None
+            return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

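The CUDA-graph path in the new `VlmCausalLM.forward` only replays a captured graph when the padded batch size matches one that was recorded. The padding rule from the code above, pulled out as a small function with a few example values (the buckets 4, 8 and multiples of 8 mirror the branches in the diff; this helper itself is illustrative, not part of the commit):

def padded_batch_size(bs: int) -> int:
    # Mirror of the bucketing used to look up a captured CUDA graph.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs  # bs in {1, 2} is used as-is


assert [padded_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]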
View File

@@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
@@ -106,7 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,