Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-08 10:22:05 +00:00)

Mllama

Commit 31a4c24f74, parent 44cdb00bbb
@@ -758,8 +758,8 @@ class MllamaTextCrossAttention(nn.Module):
         elif cache_position[0] != 0:
             key_states, value_states = (
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
+                past_key_value[self.layer_idx][0],
+                past_key_value[self.layer_idx][1],
             )
         else:
             raise ValueError(
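Note: this hunk drops the Cache-style attribute accessors (`key_cache` / `value_cache`) in favour of plain indexing, so the per-layer past can be an ordinary Python sequence. Below is a minimal sketch of the layout that the new indexing assumes; the layer count and tensor shapes are invented for illustration and are not taken from the commit.

import torch

# Assumed layout: one [keys, values] pair per decoder layer, indexable as
# past_key_value[layer_idx][0] / past_key_value[layer_idx][1].
num_layers, batch, heads, seq, head_dim = 2, 1, 8, 16, 64
past_key_value = [
    [
        torch.zeros(batch, heads, seq, head_dim),  # keys for this layer
        torch.zeros(batch, heads, seq, head_dim),  # values for this layer
    ]
    for _ in range(num_layers)
]

layer_idx = 0
key_states = past_key_value[layer_idx][0]    # replaces past_key_value.key_cache[layer_idx]
value_states = past_key_value[layer_idx][1]  # replaces past_key_value.value_cache[layer_idx]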
@@ -850,6 +850,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         self.cross_attn_mlp_gate = torch.nn.Parameter(
             weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -862,24 +863,75 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
-
-        hidden_states, attn_weights, past_key_value = self.cross_attn(
-            hidden_states=hidden_states,
-            attention_mask=cross_attention_mask,
-            cross_attention_states=cross_attention_states,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-        )
-        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
-
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        if full_text_row_masked_out_mask is not None:
-            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
-        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+        if past_key_value is not None:
+            is_mixed = False
+            if cross_attention_states is None:
+                out_hidden_states = hidden_states[:]
+                indices = []
+                for i, k in enumerate(past_key_value[self.layer_idx][0]):
+                    if isinstance(k, torch.Tensor):
+                        indices.append(i)
+                from loguru import logger
+
+                logger.info(f"Indices {indices}")
+                if len(indices) == 0:
+                    return hidden_states
+                is_mixed = True
+                if len(indices) == hidden_states.shape[0]:
+                    is_mixed = False
+
+            if is_mixed:
+                hidden_states = hidden_states[indices]
+                # Dirty hack
+                _past_key_value = [None] * len(past_key_value)
+                _past_key_value[self.layer_idx] = (
+                    torch.stack(
+                        [
+                            k
+                            for i, k in enumerate(past_key_value[self.layer_idx][0])
+                            if i in indices
+                        ],
+                        dim=0,
+                    ),
+                    torch.stack(
+                        [
+                            k
+                            for i, k in enumerate(past_key_value[self.layer_idx][1])
+                            if i in indices
+                        ],
+                        dim=0,
+                    ),
+                )
+                logger.info(f"Hidden states {hidden_states.shape}")
+                logger.info(f"k {_past_key_value[self.layer_idx][0].shape}")
+                logger.info(f"v {_past_key_value[self.layer_idx][1].shape}")
+                past_key_value = _past_key_value
+
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, attn_weights, past_key_value = self.cross_attn(
+            hidden_states=hidden_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
+        )
+        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if full_text_row_masked_out_mask is not None:
+            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
+        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+        if is_mixed:
+            out_hidden_states[indices] = hidden_states
+            hidden_states = out_hidden_states
+        from loguru import logger
+
+        logger.info(f"After Hidden states {hidden_states.shape}")
 
         return hidden_states
 
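Note: the reworked forward follows a gather/scatter pattern for mixed batches: only the rows whose cross-attention cache slot is an actual tensor (i.e. rows that saw an image) are run through the layer, and their outputs are written back into the full batch afterwards. Below is a standalone sketch of that pattern; the shapes, `cached_keys`, and `run_layer` are invented stand-ins for illustration, not names from the commit.

import torch

def run_layer(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the cross-attention block; a no-op keeps the sketch runnable.
    return x

batch, hidden = 4, 8
hidden_states = torch.randn(batch, hidden)

# One cache entry per row: a tensor for rows that saw an image, None otherwise.
cached_keys = [torch.randn(2, 16), None, torch.randn(2, 16), None]

# Gather the rows that have cross-attention state.
indices = [i for i, k in enumerate(cached_keys) if isinstance(k, torch.Tensor)]
if indices and len(indices) < batch:
    # Mixed batch: process only the selected rows, then scatter them back.
    out_hidden_states = hidden_states.clone()
    out_hidden_states[indices] = run_layer(hidden_states[indices])
    hidden_states = out_hidden_states
elif indices:
    # Every row has cross-attention state: process the whole batch.
    hidden_states = run_layer(hidden_states)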
@@ -1243,18 +1295,18 @@ class MllamaTextModel(nn.Module):
         # decoder layers
 
         for idx, decoder_layer in enumerate(self.layers):
-            if (
-                idx in self.cross_attention_layers
-                and cross_attention_states is None
-                and (
-                    past_key_values is None
-                    or (
-                        past_key_values is not None
-                        and past_key_values.get_seq_length(idx) == 0
-                    )
-                )
-            ):
-                continue
+            # if (
+            #     idx in self.cross_attention_layers
+            #     and cross_attention_states is None
+            #     and (
+            #         past_key_values is None
+            #         or (
+            #             past_key_values is not None
+            #             and any(past_key_values.get_seq_length(idx) == 0
+            #         )
+            #     )
+            # ):
+            #     continue
 
             layer_outputs = decoder_layer(
                 hidden_states,
@@ -360,7 +360,15 @@ class IdeficsCausalLMBatch(Batch):
         past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
-            if len(past_keys.shape) == 3:
+            if not isinstance(past_keys, torch.Tensor):
+                past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices]
+                past_values = [
+                    k for i, k in enumerate(past_values) if i in keep_indices
+                ]
+                layer[0] = past_keys
+                layer[1] = past_values
+                continue
+            elif len(past_keys.shape) == 3:
                 # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                 past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                 past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
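Note: on the filter path, a per-layer past stored as a Python list (one entry per sequence, as the mllama cross-attention cache is here) cannot be sliced like a tensor, so it is filtered by index. A small self-contained sketch of that filtering; the shapes are invented and `keep_indices` plays the same role as in the hunk above.

import torch

# One layer's cache stored as [keys, values], one entry per sequence in the batch.
layer = [
    [torch.randn(8, 4), None, torch.randn(8, 4)],  # keys per sequence (None = no image)
    [torch.randn(8, 4), None, torch.randn(8, 4)],  # values per sequence
]

keep_indices = [0, 2]  # sequences that survive the filter

past_keys = [k for i, k in enumerate(layer[0]) if i in keep_indices]
past_values = [v for i, v in enumerate(layer[1]) if i in keep_indices]
layer[0] = past_keys
layer[1] = past_values

assert len(layer[0]) == len(keep_indices)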
@@ -530,7 +538,14 @@ class IdeficsCausalLMBatch(Batch):
         # And ensure that we can update tensors in-place
         if isinstance(batch.past_key_values[0], tuple):
             batch.past_key_values = [
-                [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                [
+                    (
+                        t.view(len(batch), -1, *t.shape[-2:])
+                        if isinstance(t, torch.Tensor)
+                        else t
+                    )
+                    for t in layer
+                ]
                 for layer in batch.past_key_values
             ]
         elif len(batch.past_key_values[0][0].shape) == 3:
@@ -569,83 +584,98 @@ class IdeficsCausalLMBatch(Batch):
         # Iterate over attention layers
         # Concatenate past key values layer by layer to allow incremental garbage collection
         for j in range(len(first_past_kvs)):
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_keys_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_keys_shape = padded_past_keys_shape
-
-            padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
-            start_index = 0
-            for batch in batches:
-                past_keys = batch.past_key_values[j][0]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][0] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the keys to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_keys.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_keys.shape[2]
-                if batch.keys_head_dim_last:
-                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
-                        past_keys[:, :, -past_seq_len:, :]
-                    )
-                else:
-                    # BLOOM case
-                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
-                        past_keys[:, :, :, -past_seq_len:]
-                    )
-                del past_keys
-
-                start_index = end_index
-
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_values_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_values_shape = padded_past_values_shape
-            padded_past_values = first_past_kvs[j][1].new_zeros(
-                _padded_past_values_shape
-            )
-            start_index = 0
-            for batch in batches:
-                past_values = batch.past_key_values[j][1]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][1] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the past values to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_values.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_values.shape[2]
-                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
-                    past_values[:, :, -past_seq_len:, :]
-                )
-                del past_values
-
-                # Update values
-                start_index = end_index
-
-            past_key_values.append([padded_past_keys, padded_past_values])
+            if any(
+                not isinstance(batch.past_key_values[j][0], torch.Tensor)
+                for batch in batches
+            ):
+                # XXX: Special handling for cross attention for mllama
+                padded_past_keys = [
+                    k for batch in batches for k in batch.past_key_values[j][0]
+                ]
+                padded_past_values = [
+                    k for batch in batches for k in batch.past_key_values[j][1]
+                ]
+                past_key_values.append([padded_past_keys, padded_past_values])
+            else:
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_keys_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_keys_shape = padded_past_keys_shape
+
+                padded_past_keys = first_past_kvs[j][0].new_zeros(
+                    _padded_past_keys_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_keys = batch.past_key_values[j][0]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][0] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the keys to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_keys.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_keys.shape[2]
+                    if batch.keys_head_dim_last:
+                        padded_past_keys[
+                            start_index:end_index, :, -past_seq_len:, :
+                        ] = past_keys[:, :, -past_seq_len:, :]
+                    else:
+                        # BLOOM case
+                        padded_past_keys[
+                            start_index:end_index, :, :, -past_seq_len:
+                        ] = past_keys[:, :, :, -past_seq_len:]
+                    del past_keys
+
+                    start_index = end_index
+
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_values_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_values_shape = padded_past_values_shape
+                padded_past_values = first_past_kvs[j][1].new_zeros(
+                    _padded_past_values_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_values = batch.past_key_values[j][1]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][1] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past values to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_values.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_values.shape[2]
+                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_values[:, :, -past_seq_len:, :]
+                    )
+                    del past_values
+
+                    # Update values
+                    start_index = end_index
+
+                past_key_values.append([padded_past_keys, padded_past_values])
 
         return cls(
             batch_id=batches[0].batch_id,
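Note: this last hunk makes batch concatenation check whether any incoming batch stores the layer's cache as a list rather than a tensor (the mllama cross-attention case) and, if so, merges by flattening the per-sequence entries instead of allocating a padded tensor. A toy sketch of the two paths follows; `FakeBatch` and the shapes are invented for illustration, and the real dense path pads tensors to a common sequence length rather than using torch.cat.

import torch


class FakeBatch:
    """Minimal stand-in for a batch object holding one layer of KV cache."""

    def __init__(self, keys, values):
        # past_key_values[j] is [keys, values]; a single layer j == 0 here.
        self.past_key_values = [[keys, values]]


# Batch A stores the layer as per-sequence lists (a tensor for rows that saw an
# image, None otherwise); batch B stores a dense tensor of shape
# [batch, heads, seq, head_dim], which iterates row by row.
batch_a = FakeBatch([torch.randn(4, 6, 8), None], [torch.randn(4, 6, 8), None])
batch_b = FakeBatch(torch.randn(2, 4, 6, 8), torch.randn(2, 4, 6, 8))
batches = [batch_a, batch_b]

j = 0
if any(not isinstance(b.past_key_values[j][0], torch.Tensor) for b in batches):
    # List path: flatten per-sequence entries across batches, one entry per row.
    merged_keys = [k for b in batches for k in b.past_key_values[j][0]]
    merged_values = [v for b in batches for v in b.past_key_values[j][1]]
else:
    # Dense path (sketch only): concatenate along the batch dimension; the
    # original code instead copies each batch into a pre-allocated padded tensor.
    merged_keys = torch.cat([b.past_key_values[j][0] for b in batches], dim=0)
    merged_values = torch.cat([b.past_key_values[j][1] for b in batches], dim=0)

assert len(merged_keys) == 4  # 2 rows from batch A + 2 rows from batch B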