Fix Galactica batching

This commit is contained in:
OlivierDehaene 2022-12-01 18:54:53 +01:00
parent a4782da22b
commit bc4c6a406a
2 changed files with 24 additions and 23 deletions

View File

@ -156,31 +156,29 @@ class CausalLMBatch:
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
_, num_heads, head_dim, padded_sequence_length = past_keys.shape
_, num_heads, padded_sequence_length, head_dim = past_values.shape
padded_past_keys_shape = (
padded_past_values_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
head_dim,
)
# head_dim is last for BLOOM
if past_values.shape[-1] == head_dim:
past_values_head_dim_last = True
padded_past_values_shape = (
# seq_length is last for BLOOM
if past_keys.shape[-2] == head_dim:
past_keys_head_dim_last = False
padded_past_keys_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
max_sequence_length - 1,
)
elif past_values.shape[-2] == head_dim:
past_values_head_dim_last = False
padded_past_values_shape = padded_past_keys_shape
elif past_keys.shape[-1] == head_dim:
past_keys_head_dim_last = True
padded_past_keys_shape = padded_past_values_shape
else:
raise ValueError(
f"past_values shape {past_values.shape} is not valid"
)
raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
# This will run only once per layer
if j == len(past_key_values):
@ -197,24 +195,24 @@ class CausalLMBatch:
past_key_values.append((padded_past_keys, padded_past_values))
# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
if past_values_head_dim_last:
past_key_values[j][1][
if past_keys_head_dim_last:
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
:,
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
else:
past_key_values[j][1][
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
start_index += batch.size

View File

@ -204,6 +204,9 @@ class GalacticaSharded(Galactica):
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
try:
module = model.get_submodule(module_name)