fix galactica batching

OlivierDehaene 2022-12-01 18:54:53 +01:00
parent a4782da22b
commit bc4c6a406a
2 changed files with 24 additions and 23 deletions


@@ -156,31 +156,29 @@ class CausalLMBatch:
                 past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                 past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
-                _, num_heads, head_dim, padded_sequence_length = past_keys.shape

-                padded_past_keys_shape = (
-                    total_batch_size,
-                    num_heads,
-                    head_dim,
-                    max_sequence_length - 1,
-                )
-
-                # head_dim is last for BLOOM
-                if past_values.shape[-1] == head_dim:
-                    past_values_head_dim_last = True
-                    padded_past_values_shape = (
-                        total_batch_size,
-                        num_heads,
-                        max_sequence_length - 1,
-                        head_dim,
-                    )
-                elif past_values.shape[-2] == head_dim:
-                    past_values_head_dim_last = False
-                    padded_past_values_shape = padded_past_keys_shape
-                else:
-                    raise ValueError(
-                        f"past_values shape {past_values.shape} is not valid"
-                    )
+                _, num_heads, padded_sequence_length, head_dim = past_values.shape
+
+                padded_past_values_shape = (
+                    total_batch_size,
+                    num_heads,
+                    max_sequence_length - 1,
+                    head_dim,
+                )
+
+                # seq_length is last for BLOOM
+                if past_keys.shape[-2] == head_dim:
+                    past_keys_head_dim_last = False
+                    padded_past_keys_shape = (
+                        total_batch_size,
+                        num_heads,
+                        head_dim,
+                        max_sequence_length - 1,
+                    )
+                elif past_keys.shape[-1] == head_dim:
+                    past_keys_head_dim_last = True
+                    padded_past_keys_shape = padded_past_values_shape
+                else:
+                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")

                 # This will run only once per layer
                 if j == len(past_key_values):
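
Side note on the change above: the cache layout differs between model families. BLOOM's key cache stores head_dim before the sequence axis (its value cache is the reverse), while OPT/Galactica-style models keep head_dim last for both keys and values, which is why the branch now probes past_keys rather than past_values. Below is a minimal, self-contained sketch of that dimension check, using made-up shapes and a hypothetical helper name (head_dim_last is not part of the repository):

import torch

num_heads, head_dim, seq_len = 4, 64, 7

# BLOOM-style key cache: head_dim comes before the sequence axis.
bloom_keys = torch.randn(2, num_heads, head_dim, seq_len)
# OPT/Galactica-style key cache: head_dim is the last axis.
opt_keys = torch.randn(2, num_heads, seq_len, head_dim)

def head_dim_last(past_keys: torch.Tensor, head_dim: int) -> bool:
    # Same idea as the branch in the diff: find which axis holds head_dim.
    # (Ambiguous only in the corner case seq_len == head_dim.)
    if past_keys.shape[-2] == head_dim:
        return False  # sequence axis is last (BLOOM layout)
    if past_keys.shape[-1] == head_dim:
        return True   # head_dim is last (OPT/Galactica layout)
    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")

assert head_dim_last(bloom_keys, head_dim) is False
assert head_dim_last(opt_keys, head_dim) is True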
@@ -197,24 +195,24 @@ class CausalLMBatch:
                     past_key_values.append((padded_past_keys, padded_past_values))

                 # We slice the past keys and values to remove the padding from previous batches
-                past_key_values[j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-
-                if past_values_head_dim_last:
-                    past_key_values[j][1][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_sequence_length - 1) :,
-                        :,
-                    ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
-                else:
-                    past_key_values[j][1][
-                        start_index:end_index,
-                        :,
-                        :,
-                        -(batch.max_sequence_length - 1) :,
-                    ] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
+                if past_keys_head_dim_last:
+                    past_key_values[j][0][
+                        start_index:end_index,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                        :,
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
+                else:
+                    past_key_values[j][0][
+                        start_index:end_index,
+                        :,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+
+                past_key_values[j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

                 start_index += batch.size
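
To make the slicing above easier to follow out of context, here is a self-contained sketch of the underlying pattern: when two running batches are merged, each layer's key/value cache is copied into one larger buffer that is left-padded to the longest sequence, so the most recent positions stay right-aligned. Shapes and variable names below are invented for illustration (standard batch, heads, seq, head_dim layout), not taken from the repository:

import torch

num_heads, head_dim = 4, 8
# Two batches being merged, with 5 and 9 past tokens respectively.
keys_a = torch.randn(2, num_heads, 5, head_dim)
keys_b = torch.randn(3, num_heads, 9, head_dim)

total_batch_size = keys_a.shape[0] + keys_b.shape[0]
# Mirrors the code above: the cache holds max_sequence_length - 1 positions.
max_sequence_length = max(keys_a.shape[2], keys_b.shape[2]) + 1

# One padded buffer for the merged batch; unused positions stay zero.
padded_keys = torch.zeros(
    total_batch_size, num_heads, max_sequence_length - 1, head_dim
)

start_index = 0
for keys in (keys_a, keys_b):
    end_index = start_index + keys.shape[0]
    seq_len = keys.shape[2]
    # Right-align each batch's cache so the latest tokens line up across batches.
    padded_keys[start_index:end_index, :, -seq_len:, :] = keys
    start_index = end_index

assert torch.equal(padded_keys[0, :, -5:, :], keys_a[0])
assert torch.equal(padded_keys[2, :, -9:, :], keys_b[0])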


@@ -204,6 +204,9 @@ class GalacticaSharded(Galactica):
             file, framework="pt", device=str(device) if not quantize else "cpu"
         ) as f:
             for name in f.keys():
+                if name == "lm_head.weight":
+                    continue
                 module_name, param_name = name.rsplit(".", 1)
                 try:
                     module = model.get_submodule(module_name)
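
For context on the added skip: presumably Galactica, like other OPT-architecture models, ties its output projection to the input embedding matrix, so the checkpoint's separate lm_head.weight entry is redundant when loading sharded weights. A toy illustration of why skipping it is safe with tied weights; the module and tensor names here are invented, only the skip pattern mirrors the diff:

import torch
from torch import nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # weight tying

model = TinyTiedLM()
# A checkpoint that (redundantly) stores both copies of the tied matrix.
state = {
    "embed_tokens.weight": torch.randn(16, 8),
    "lm_head.weight": torch.randn(16, 8),
}

for name, tensor in state.items():
    if name == "lm_head.weight":
        continue  # already covered by the tied embedding weight
    module_name, param_name = name.rsplit(".", 1)
    module = model.get_submodule(module_name)
    getattr(module, param_name).data = tensor

# The tied head automatically picks up the loaded embedding weight.
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()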