Nicolas Patry 2024-09-25 20:41:40 +02:00
parent 44cdb00bbb
commit 31a4c24f74
2 changed files with 183 additions and 101 deletions

File 1 of 2

@@ -758,8 +758,8 @@ class MllamaTextCrossAttention(nn.Module):
         elif cache_position[0] != 0:
             key_states, value_states = (
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
+                past_key_value[self.layer_idx][0],
+                past_key_value[self.layer_idx][1],
             )
         else:
             raise ValueError(
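
The hunk above drops the attribute-style cache access (past_key_value.key_cache[idx] / .value_cache[idx]) in favour of plain per-layer indexing. A minimal sketch of the two layouts, using an illustrative AttrCache stand-in rather than the repository's actual cache class:

import torch

# Illustrative stand-in for the old attribute-style cache (not the repository's real class).
class AttrCache:
    def __init__(self, num_layers: int):
        self.key_cache = [torch.zeros(1, 2, 3, 4) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(1, 2, 3, 4) for _ in range(num_layers)]

# Layout assumed by the new code: one (key, value) pair per layer, indexable by position.
def to_legacy_layout(cache: AttrCache):
    return [(k, v) for k, v in zip(cache.key_cache, cache.value_cache)]

cache = AttrCache(num_layers=2)
past_key_value = to_legacy_layout(cache)
layer_idx = 0
# Old: cache.key_cache[layer_idx], cache.value_cache[layer_idx]
# New: past_key_value[layer_idx][0], past_key_value[layer_idx][1]
key_states, value_states = past_key_value[layer_idx]
print(key_states.shape, value_states.shape)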
@@ -850,6 +850,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         self.cross_attn_mlp_gate = torch.nn.Parameter(
             weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -862,24 +863,75 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        if past_key_value is not None:
+            is_mixed = False
+            if cross_attention_states is None:
+                out_hidden_states = hidden_states[:]
+                indices = []
+                for i, k in enumerate(past_key_value[self.layer_idx][0]):
+                    if isinstance(k, torch.Tensor):
+                        indices.append(i)
+                from loguru import logger
+
+                logger.info(f"Indices {indices}")
+                if len(indices) == 0:
+                    return hidden_states
+                is_mixed = True
+                if len(indices) == hidden_states.shape[0]:
+                    is_mixed = False
+
+            if is_mixed:
+                hidden_states = hidden_states[indices]
+                # Dirty hack
+                _past_key_value = [None] * len(past_key_value)
+                _past_key_value[self.layer_idx] = (
+                    torch.stack(
+                        [
+                            k
+                            for i, k in enumerate(past_key_value[self.layer_idx][0])
+                            if i in indices
+                        ],
+                        dim=0,
+                    ),
+                    torch.stack(
+                        [
+                            k
+                            for i, k in enumerate(past_key_value[self.layer_idx][1])
+                            if i in indices
+                        ],
+                        dim=0,
+                    ),
+                )
+                logger.info(f"Hidden states {hidden_states.shape}")
+                logger.info(f"k {_past_key_value[self.layer_idx][0].shape}")
+                logger.info(f"v {_past_key_value[self.layer_idx][1].shape}")
+                past_key_value = _past_key_value
+
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         hidden_states, attn_weights, past_key_value = self.cross_attn(
             hidden_states=hidden_states,
             attention_mask=cross_attention_mask,
             cross_attention_states=cross_attention_states,
             past_key_value=past_key_value,
             cache_position=cache_position,
         )
         hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
 
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         if full_text_row_masked_out_mask is not None:
             hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
 
+        if is_mixed:
+            out_hidden_states[indices] = hidden_states
+            hidden_states = out_hidden_states
+
+        from loguru import logger
+
+        logger.info(f"After Hidden states {hidden_states.shape}")
         return hidden_states
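
The reworked forward above only runs cross-attention for the batch rows that actually have cached cross-attention entries (i.e. saw an image), then writes the results back into the full batch. A self-contained sketch of that gather/compute/scatter pattern, with a hypothetical layer_fn standing in for the real cross-attention call:

import torch

def run_on_rows_with_images(hidden_states, has_kv_entry, layer_fn):
    # Rows whose cross-attention cache entry is an actual tensor (an image was seen).
    indices = [i for i, present in enumerate(has_kv_entry) if present]
    if len(indices) == 0:
        return hidden_states                      # text-only batch: skip the layer entirely
    if len(indices) == hidden_states.shape[0]:
        return layer_fn(hidden_states)            # homogeneous batch: no gather needed
    out = hidden_states.clone()
    out[indices] = layer_fn(hidden_states[indices])  # gather, compute, scatter back
    return out

# Example with a dummy layer that doubles its input:
hs = torch.randn(4, 5, 8)
mixed = run_on_rows_with_images(hs, [True, False, True, False], lambda x: x * 2)
print(mixed.shape)  # torch.Size([4, 5, 8])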
@@ -1243,18 +1295,18 @@ class MllamaTextModel(nn.Module):
         # decoder layers
         for idx, decoder_layer in enumerate(self.layers):
-            if (
-                idx in self.cross_attention_layers
-                and cross_attention_states is None
-                and (
-                    past_key_values is None
-                    or (
-                        past_key_values is not None
-                        and past_key_values.get_seq_length(idx) == 0
-                    )
-                )
-            ):
-                continue
+            # if (
+            #     idx in self.cross_attention_layers
+            #     and cross_attention_states is None
+            #     and (
+            #         past_key_values is None
+            #         or (
+            #             past_key_values is not None
+            #             and any(past_key_values.get_seq_length(idx) == 0
+            #         )
+            #     )
+            # ):
+            #     continue
             layer_outputs = decoder_layer(
                 hidden_states,

File 2 of 2

@@ -360,7 +360,15 @@ class IdeficsCausalLMBatch(Batch):
         past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
-            if len(past_keys.shape) == 3:
+            if not isinstance(past_keys, torch.Tensor):
+                past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices]
+                past_values = [
+                    k for i, k in enumerate(past_values) if i in keep_indices
+                ]
+                layer[0] = past_keys
+                layer[1] = past_values
+                continue
+            elif len(past_keys.shape) == 3:
                 # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                 past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                 past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
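
When a layer's keys/values are Python lists with one entry per request (the mllama cross-attention case) rather than stacked tensors, filtering a batch reduces to keeping the selected positions. A small sketch under that assumption; filter_layer_kv and the example data are illustrative, not part of the repository:

import torch

def filter_layer_kv(layer, keep_indices):
    # layer is a mutable [past_keys, past_values] pair for one attention layer.
    past_keys, past_values = layer
    if not isinstance(past_keys, torch.Tensor):
        # List layout (per-request entries, possibly None): keep the surviving requests.
        layer[0] = [k for i, k in enumerate(past_keys) if i in keep_indices]
        layer[1] = [v for i, v in enumerate(past_values) if i in keep_indices]
    else:
        # Tensor layout: select along the batch dimension.
        layer[0] = past_keys[keep_indices]
        layer[1] = past_values[keep_indices]
    return layer

# Example: request 1 had no image, so its cross-attention entry is None.
layer = [
    [torch.randn(2, 3, 4), None, torch.randn(2, 3, 4)],
    [torch.randn(2, 3, 4), None, torch.randn(2, 3, 4)],
]
print(filter_layer_kv(layer, keep_indices=[0, 2])[0])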
@@ -530,7 +538,14 @@ class IdeficsCausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if isinstance(batch.past_key_values[0], tuple):
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    [
+                        (
+                            t.view(len(batch), -1, *t.shape[-2:])
+                            if isinstance(t, torch.Tensor)
+                            else t
+                        )
+                        for t in layer
+                    ]
                     for layer in batch.past_key_values
                 ]
             elif len(batch.past_key_values[0][0].shape) == 3:
@@ -569,83 +584,98 @@ class IdeficsCausalLMBatch(Batch):
         # Iterate over attention layers
         # Concatenate past key values layer by layer to allow incremental garbage collection
         for j in range(len(first_past_kvs)):
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_keys_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_keys_shape = padded_past_keys_shape
-
-            padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
-            start_index = 0
-            for batch in batches:
-                past_keys = batch.past_key_values[j][0]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][0] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the keys to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_keys.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_keys.shape[2]
-                if batch.keys_head_dim_last:
-                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
-                        past_keys[:, :, -past_seq_len:, :]
-                    )
-                else:
-                    # BLOOM case
-                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
-                        past_keys[:, :, :, -past_seq_len:]
-                    )
-                del past_keys
-
-                start_index = end_index
-
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_values_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_values_shape = padded_past_values_shape
-            padded_past_values = first_past_kvs[j][1].new_zeros(
-                _padded_past_values_shape
-            )
-            start_index = 0
-            for batch in batches:
-                past_values = batch.past_key_values[j][1]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][1] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the past values to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_values.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_values.shape[2]
-                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
-                    past_values[:, :, -past_seq_len:, :]
-                )
-                del past_values
-
-                # Update values
-                start_index = end_index
-
-            past_key_values.append([padded_past_keys, padded_past_values])
+            if any(
+                not isinstance(batch.past_key_values[j][0], torch.Tensor)
+                for batch in batches
+            ):
+                # XXX: Special handling for cross attention for mllama
+                padded_past_keys = [
+                    k for batch in batches for k in batch.past_key_values[j][0]
+                ]
+                padded_past_values = [
+                    k for batch in batches for k in batch.past_key_values[j][1]
+                ]
+                past_key_values.append([padded_past_keys, padded_past_values])
+            else:
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_keys_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_keys_shape = padded_past_keys_shape
+
+                padded_past_keys = first_past_kvs[j][0].new_zeros(
+                    _padded_past_keys_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_keys = batch.past_key_values[j][0]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][0] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the keys to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_keys.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_keys.shape[2]
+                    if batch.keys_head_dim_last:
+                        padded_past_keys[
+                            start_index:end_index, :, -past_seq_len:, :
+                        ] = past_keys[:, :, -past_seq_len:, :]
+                    else:
+                        # BLOOM case
+                        padded_past_keys[
+                            start_index:end_index, :, :, -past_seq_len:
+                        ] = past_keys[:, :, :, -past_seq_len:]
+                    del past_keys
+
+                    start_index = end_index
+
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_values_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_values_shape = padded_past_values_shape
+                padded_past_values = first_past_kvs[j][1].new_zeros(
+                    _padded_past_values_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_values = batch.past_key_values[j][1]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][1] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past values to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_values.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_values.shape[2]
+                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_values[:, :, -past_seq_len:, :]
+                    )
+                    del past_values
+
+                    # Update values
+                    start_index = end_index
+
+                past_key_values.append([padded_past_keys, padded_past_values])
 
         return cls(
             batch_id=batches[0].batch_id,
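
For those list-based cross-attention layers, concatenating batches is plain list concatenation, while tensor layers keep the padded-copy path shown above. A stripped-down sketch of the branch, with concat_layer_kv as a hypothetical helper and the padding logic elided:

import torch

def concat_layer_kv(batches_kv):
    # batches_kv: for a single layer, one [keys, values] pair per incoming batch.
    if any(not isinstance(kv[0], torch.Tensor) for kv in batches_kv):
        # Cross-attention style: per-request entries stored in plain Python lists.
        keys = [k for kv in batches_kv for k in kv[0]]
        values = [v for kv in batches_kv for v in kv[1]]
        return [keys, values]
    # Self-attention style: concatenate along the batch dimension (padding elided here).
    keys = torch.cat([kv[0] for kv in batches_kv], dim=0)
    values = torch.cat([kv[1] for kv in batches_kv], dim=0)
    return [keys, values]

# One batch with an image (tensor entry) and one text-only batch (None entry):
batch_a = [[torch.randn(1, 2, 3, 4)], [torch.randn(1, 2, 3, 4)]]
batch_b = [[None], [None]]
merged = concat_layer_kv([batch_a, batch_b])
print(len(merged[0]))  # 2 per-request key entries, one tensor and one None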