Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
More work on the CLIP Side.
parent b8be0d1ae7
commit 5f4b395480
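The commit title is terse, so a brief orientation may help: on the CLIP side the modules now take `prefix` and `weights` arguments, load pre-sharded tensors instead of initializing their own, and replace the separate q/k/v projections in `CLIPAttention` with a single fused `TensorParallelColumnLinear` whose output is split at runtime. The sketch below only illustrates the shape bookkeeping behind that split in plain PyTorch; the toy dimensions, `world_size`, and `qkv_proj` are assumptions for illustration, not code from this repository.

```python
# Illustrative sketch, not repository code: shape bookkeeping of a fused,
# tensor-parallel q/k/v projection. `world_size` plays the role of
# weights.process_group.size() in the patched CLIPAttention.
import torch
import torch.nn as nn

embed_dim, num_heads, world_size = 768, 12, 4
head_size = embed_dim // num_heads
local_heads = num_heads // world_size  # heads owned by this shard

# One fused projection per shard: it emits q, k and v for the local heads.
qkv_proj = nn.Linear(embed_dim, 3 * head_size * local_heads)

hidden_states = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
qkv = qkv_proj(hidden_states)

# Mirrors `qkv.split([self.head_size * self.num_heads] * 3, dim=2)` in the
# diff, where num_heads has already been divided by the world size.
query, key, value = qkv.split([head_size * local_heads] * 3, dim=2)
print(query.shape, key.shape, value.shape)  # each: (2, 10, head_size * local_heads)
```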
@@ -10,16 +10,23 @@ from transformers.modeling_attn_mask_utils import (
 )
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 
+from text_generation_server.utils.layers import (
+    TensorParallelEmbedding,
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
+
 
 class CLIPVisionEmbeddings(nn.Module):
-    def __init__(self, config: CLIPVisionConfig):
+    def __init__(self, prefix, config: CLIPVisionConfig, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
 
-        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+        # TODO Should we TP this ?
+        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
 
         self.patch_embedding = nn.Conv2d(
             in_channels=config.num_channels,
@@ -28,13 +35,18 @@ class CLIPVisionEmbeddings(nn.Module):
             stride=self.patch_size,
             bias=False,
         )
+        self.patch_embedding.weight = nn.Parameter(
+            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+        )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.position_embedding = TensorParallelEmbedding(
+            prefix=f"{prefix}.position_embedding", weights=weights
+        )
         self.register_buffer(
             "position_ids",
-            torch.arange(self.num_positions).expand((1, -1)),
+            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
             persistent=False,
         )
 
@@ -94,28 +106,38 @@ class CLIPTextEmbeddings(nn.Module):
 class CLIPAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
+        self.head_size = self.embed_dim // self.num_heads
+        if self.head_size * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
-        self.scale = self.head_dim**-0.5
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.scale = self.head_size**-0.5
         self.dropout = config.attention_dropout
 
-        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.qkv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=False,
+        )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return (
-            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
            .transpose(1, 2)
            .contiguous()
         )
@@ -132,11 +154,20 @@ class CLIPAttention(nn.Module):
         bsz, tgt_len, embed_dim = hidden_states.size()
 
         # get query proj
-        query_states = self.q_proj(hidden_states) * self.scale
-        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        qkv = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_size * self.num_heads,
+            ]
+            * 3,
+            dim=2,
+        )
+        query_states = query_states * self.scale
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_size)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
@@ -176,48 +207,38 @@ class CLIPAttention(nn.Module):
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-        if output_attentions:
-            # this operation is a bit akward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len
-            )
-            attn_weights = attn_weights_reshaped.view(
-                bsz * self.num_heads, tgt_len, src_len
-            )
-        else:
-            attn_weights_reshaped = None
-
         attn_probs = nn.functional.dropout(
             attn_weights, p=self.dropout, training=self.training
         )
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
                 f" {attn_output.size()}"
             )
 
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights_reshaped
+        return attn_output, None
 
 
 class CLIPMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.fc1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
@@ -227,13 +248,19 @@ class CLIPMLP(nn.Module):
 
 
 class CLIPEncoderLayer(nn.Module):
-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = CLIPAttention(config)
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = CLIPMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = CLIPAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.layer_norm1 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+        )
+        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.layer_norm2 = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+        )
 
     def forward(
         self,
@@ -281,73 +308,6 @@ class CLIPPreTrainedModel(nn.Module):
     base_model_prefix = "clip"
     supports_gradient_checkpointing = True
 
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        factor = self.config.initializer_factor
-        if isinstance(module, CLIPTextEmbeddings):
-            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
-        elif isinstance(module, CLIPVisionEmbeddings):
-            factor = self.config.initializer_factor
-            nn.init.normal_(
-                module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor
-            )
-            nn.init.normal_(
-                module.patch_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-            nn.init.normal_(
-                module.position_embedding.weight,
-                std=module.config.initializer_range * factor,
-            )
-        elif isinstance(module, CLIPAttention):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.embed_dim**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            out_proj_std = (module.embed_dim**-0.5) * factor
-            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
-            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
-        elif isinstance(module, CLIPMLP):
-            factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.config.hidden_size**-0.5)
-                * ((2 * module.config.num_hidden_layers) ** -0.5)
-                * factor
-            )
-            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
-            nn.init.normal_(module.fc1.weight, std=fc_std)
-            nn.init.normal_(module.fc2.weight, std=in_proj_std)
-        elif isinstance(module, CLIPModel):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPVisionModelWithProjection):
-            nn.init.normal_(
-                module.visual_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-        elif isinstance(module, CLIPTextModelWithProjection):
-            nn.init.normal_(
-                module.text_projection.weight,
-                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
-            )
-
-        if isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
 
 CLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -458,13 +418,17 @@ class CLIPEncoder(nn.Module):
         config: CLIPConfig
     """
 
-    def __init__(self, config: CLIPConfig):
+    def __init__(self, prefix, config: CLIPConfig, weights):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList(
-            [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+            [
+                CLIPEncoderLayer(
+                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
-        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -523,29 +487,15 @@ class CLIPEncoder(nn.Module):
 
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
         return hidden_states
 
 
@@ -555,7 +505,9 @@ class CLIPTextTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = CLIPTextEmbeddings(config)
-        self.encoder = CLIPEncoder(config)
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
         # For `pooled_output` computation
@@ -710,10 +662,20 @@ class CLIPVisionTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
 
-        self.embeddings = CLIPVisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.encoder = CLIPEncoder(config)
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.embeddings = CLIPVisionEmbeddings(
+            prefix=f"{prefix}.embeddings", config=config, weights=weights
+        )
+        self.pre_layrnorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+        )
+        self.encoder = CLIPEncoder(
+            prefix=f"{prefix}.encoder", config=config, weights=weights
+        )
+        self.post_layernorm = nn.LayerNorm.load(
+            prefix=f"{prefix}.post_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
 
     def forward(
         self,
@@ -385,7 +385,32 @@ class MistralModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        return self.with_hidden_states(
+            hidden_states,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
 
+    def with_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+    ):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -409,7 +434,6 @@ class MistralModel(torch.nn.Module):
         )
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
         return hidden_states
 
-
@@ -107,7 +107,9 @@ def load_vision_model(prefix, config, weights):
             CLIPVisionTransformer,
         )
 
-        return CLIPVisionTransformer(prefix, config, weights)
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
 
@@ -133,11 +135,13 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        # self.vision_tower = load_vision_model(
-        #     prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, weights=weights
-        # )
+        self.vision_tower = load_vision_model(
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+            config=config.vision_config,
+            weights=weights,
+        )
 
-        # self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
 
         self.image_newline = weights.get_tensor("image_newline")
 
@@ -153,7 +157,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
         )
-        # self.post_init()
 
     # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
     def _merge_input_ids_with_image_features(
@@ -278,118 +281,102 @@ class LlavaNextForConditionalGeneration(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
     ):
+        if pixel_values is not None and len(pixel_values) > 0:
+            num_special_image_tokens = (
+                input_ids == self.config.image_token_index
+            ).sum()
+            assert num_special_image_tokens == len(
+                pixel_values
+            ), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+            # 1. Extract the input embeddings
+            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
 
-        # vision_feature_layer = (
-        # vision_feature_layer
-        # if vision_feature_layer is not None
-        # else self.config.vision_feature_layer
-        # )
-        # vision_feature_select_strategy = (
-        # vision_feature_select_strategy
-        # if vision_feature_select_strategy is not None
-        # else self.config.vision_feature_select_strategy
-        # )
+            # 2. Merge text and images
+            num_images, num_patches, channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.view(
+                num_images * num_patches, channels, height, width
+            )
+            image_features = self.vision_tower(pixel_values)
 
-        # if cu_seqlen_prefill is not None:
-        # pass
-        # # # 1. Extract the input embeddings
-        # # inputs_embeds = self.get_input_embeddings()(input_ids)
+            selected_image_feature = image_features.hidden_states[
+                self.config.vision_feature_layer
+            ]
 
-        # # # 2. Merge text and images
-        # # if pixel_values is not None and input_ids.shape[1] != 1:
-        # # batch_size, num_patches, num_channels, height, width = (
-        # # pixel_values.shape
-        # # )
-        # # reshaped_pixel_values = pixel_values.view(
-        # # batch_size * num_patches, num_channels, height, width
-        # # )
-        # # image_features = self.vision_tower(
-        # # reshaped_pixel_values, output_hidden_states=True
-        # # )
+            if self.config.vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.config.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise RuntimeError(
+                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+                )
 
-        # # selected_image_feature = image_features.hidden_states[
-        # # vision_feature_layer
-        # # ]
+            image_features = self.multi_modal_projector(selected_image_feature)
 
-        # # if vision_feature_select_strategy == "default":
-        # # selected_image_feature = selected_image_feature[:, 1:]
-        # # elif vision_feature_select_strategy == "full":
-        # # selected_image_feature = selected_image_feature
+            # split up image_features for each of the individual images
+            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+            # if we assume each image has 5 image features (base image + 4 patches)
+            split_sizes = [num_patches] * num_images
+            image_features = torch.split(image_features, split_sizes, dim=0)
 
-        # # image_features = self.multi_modal_projector(selected_image_feature)
+            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )
 
-        # # # split up image_features for each of the individual images
-        # # # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-        # # # if we assume each image has 5 image features (base image + 4 patches)
-        # # split_sizes = [image.shape[0] for image in pixel_values]
-        # # image_features = torch.split(image_features, split_sizes, dim=0)
+            new_image_features = []
+            for image_idx, image_feature in enumerate(image_features):
+                if image_feature.shape[0] > 1:
+                    base_image_feature = image_feature[0]
+                    image_feature = image_feature[1:]
 
-        # # # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-        # # height = width = (
-        # # self.config.vision_config.image_size
-        # # // self.config.vision_config.patch_size
-        # # )
+                    if height * width != base_image_feature.shape[0]:
+                        raise ValueError(
+                            "The number of patches is not consistent with the image size."
+                        )
+                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                        image_sizes[image_idx],
+                        self.config.image_grid_pinpoints,
+                        self.config.vision_config.image_size,
+                    )
+                    image_feature = image_feature.view(
+                        num_patch_height, num_patch_width, height, width, -1
+                    )
+                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            self.image_newline[:, None, None].expand(
+                                *image_feature.shape[:-1], 1
+                            ),
+                        ),
+                        dim=-1,
+                    )
+                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                    image_feature = torch.cat(
+                        (base_image_feature, image_feature), dim=0
+                    )
+                else:
+                    image_feature = image_feature[0]
+                    image_feature = torch.cat(
+                        (image_feature, self.image_newline[None]), dim=0
+                    )
+                new_image_features.append(image_feature)
+            image_features = torch.stack(new_image_features, dim=0)
 
-        # # new_image_features = []
-        # # for image_idx, image_feature in enumerate(image_features):
-        # # if image_feature.shape[0] > 1:
-        # # base_image_feature = image_feature[0]
-        # # image_feature = image_feature[1:]
-        # # if height * width != base_image_feature.shape[0]:
-        # # raise ValueError(
-        # # "The number of patches is not consistent with the image size."
-        # # )
-        # # num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-        # # image_sizes[image_idx],
-        # # self.config.image_grid_pinpoints,
-        # # self.config.vision_config.image_size,
-        # # )
-        # # image_feature = image_feature.view(
-        # # num_patch_height, num_patch_width, height, width, -1
-        # # )
-        # # image_feature = image_feature.permute(
-        # # 4, 0, 2, 1, 3
-        # # ).contiguous()
-        # # image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-        # # image_feature = unpad_image(
-        # # image_feature, image_sizes[image_idx]
-        # # )
-        # # image_feature = torch.cat(
-        # # (
-        # # image_feature,
-        # # self.image_newline[:, None, None].expand(
-        # # *image_feature.shape[:-1], 1
-        # # ),
-        # # ),
-        # # dim=-1,
-        # # )
-        # # image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-        # # image_feature = torch.cat(
-        # # (base_image_feature, image_feature), dim=0
-        # # )
-        # # else:
-        # # image_feature = image_feature[0]
-        # # image_feature = torch.cat(
-        # # (image_feature, self.image_newline[None]), dim=0
-        # # )
-        # # new_image_features.append(image_feature)
-        # # image_features = torch.stack(new_image_features, dim=0)
-
-        # # inputs_embeds, attention_mask, labels, position_ids = (
-        # # self._merge_input_ids_with_image_features(
-        # # image_features, inputs_embeds, input_ids, attention_mask, labels
-        # # )
-        # # )
-        # # if labels is None:
-        # # labels = torch.full_like(
-        # # attention_mask, self.config.ignore_index
-        # # ).to(torch.long)
+            inputs_embeds, attention_mask, labels, position_ids = (
+                self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+            )
+            if labels is None:
+                labels = torch.full_like(attention_mask, self.config.ignore_index).to(
+                    torch.long
+                )
 
         logits = self.language_model(
             input_ids,
@@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
             max_tokens=self.blocks * BLOCK_SIZE,
         )
 
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer):
+        batch_inputs = []
+        max_truncation = 0
+        for r in requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        return batch_tokenized_inputs
+
     @classmethod
     def from_pb(
         cls,
@@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
@@ -65,19 +65,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window, sliding_window_blocks = get_sliding_windows()
 
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-
         position_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []
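The hunks above, together with the FlashCausalLMBatch hunks before them, split the old monolithic `from_pb` into a reusable `batch_tokenized_inputs` classmethod plus a `from_tokenized` constructor, so that a subclass can change how text is tokenized while reusing the rest of the batch construction. Below is a minimal sketch of that factoring only; `Request`, `MiniBatch` and `toy_tokenizer` are hypothetical stand-ins, not the repository's classes.

```python
# Sketch of the factoring pattern; names here are illustrative stand-ins.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    inputs: str
    truncate: int


@dataclass
class MiniBatch:
    input_ids: List[List[int]] = field(default_factory=list)

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        batch_inputs = [r.inputs for r in requests]
        max_truncation = max(r.truncate for r in requests)
        return tokenizer(batch_inputs, truncation=True, max_length=max_truncation)[
            "input_ids"
        ]

    @classmethod
    def from_tokenized(cls, requests, batch_tokenized_inputs):
        # Stand-in for the real position / slot / block-table bookkeeping.
        return cls(input_ids=batch_tokenized_inputs)

    @classmethod
    def from_pb(cls, requests, tokenizer):
        # A subclass can override batch_tokenized_inputs (e.g. to pull image
        # chunks out of the prompt) and reuse from_tokenized unchanged.
        return cls.from_tokenized(requests, cls.batch_tokenized_inputs(requests, tokenizer))


def toy_tokenizer(batch_inputs, truncation=True, max_length=None):
    # Crude whitespace "tokenizer", only here to make the sketch runnable.
    return {"input_ids": [[hash(w) % 1000 for w in t.split()][:max_length] for t in batch_inputs]}


print(MiniBatch.from_pb([Request("hello world", truncate=8)], toy_tokenizer).input_ids)
```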
@@ -927,7 +927,7 @@ class IdeficsCausalLMBatch(Batch):
         )
 
     @classmethod
-    def from_pb(
+    def from_pb_processor(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
@@ -1,12 +1,18 @@
 import re
+import torch
 
 from opentelemetry import trace
 from typing import Optional, Tuple, List, Type, Dict
 
+from transformers import PreTrainedTokenizerBase
+from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
+from text_generation_server.models.cache_manager import (
+    get_cache_manager,
+)
 
 tracer = trace.get_tracer(__name__)
 
@@ -31,13 +37,65 @@ def split(string) -> List[Dict[str, str]]:
 
 
 class VlmCausalLMBatch(FlashMistralBatch):
-    pass
+    pixel_values: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
+
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor):
+        batch_inputs = []
+        images = []
+        max_truncation = 0
+        for r in requests:
+            chunks = split(r.inputs)
+            full_text = ""
+            for chunk in chunks:
+                if chunk["type"] == "text":
+                    full_text += chunk["content"]
+                elif chunk["type"] == "image":
+                    full_text += "<image>"
+                    images.append(chunk["content"])
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+
+            batch_inputs.append(full_text)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+        images = processor.image_processor.fetch_images(images)
+        if images:
+            image_inputs = processor.image_processor(images, return_tensors="pt")
+        else:
+            image_inputs = None
+        return batch_tokenized_inputs, image_inputs
+
+    @classmethod
+    def from_pb_processor(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        processor,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "VlmCausalLMBatch":
+        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+            pb.requests, tokenizer, processor
+        )
+        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if image_inputs is not None:
+            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+        else:
+            batch.pixel_values = None
+            batch.image_sizes = None
+        return batch
 
 
 class VlmCausalLM(BaseFlashMistral):
     @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
+    def batch_type(self) -> Type[VlmCausalLMBatch]:
+        return VlmCausalLMBatch
 
     def get_layer_config(self, model) -> Tuple[int, int, int]:
         return (
@@ -48,3 +106,122 @@ class VlmCausalLM(BaseFlashMistral):
 
     def max_past(self) -> Optional[int]:
         return getattr(self.model.language_model, "max_past", None)
+
+    def forward(
+        self, batch: VlmCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            logits, speculative_logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+                pixel_values=batch.pixel_values,
+                image_sizes=batch.image_sizes,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            if batch.pixel_values is not None:
+                batch.pixel_values = None
+            if batch.image_sizes is not None:
+                batch.image_sizes = None
+            return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
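One detail of the new `VlmCausalLM.forward` above that is easy to miss is the batch-size bucketing applied before looking up a captured CUDA graph. The snippet below restates just that rule as a standalone helper so it can be checked in isolation; the helper name is made up for illustration and is not part of the repository.

```python
# Illustration only: the padded-batch-size rule from the forward pass above,
# pulled out into a hypothetical helper so the bucketing is easy to test.
def padded_batch_size(bs: int) -> int:
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8  # round up to the next multiple of 8
    return bs  # a batch size of 1 or 2 is used as-is


assert [padded_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]
```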
@@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
@@ -106,7 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,